Skip to content

Commit

Permalink
Server side parameters via with_param
Browse files Browse the repository at this point in the history
Fixes #142
  • Loading branch information
serprex committed Nov 7, 2024
1 parent 0290fba commit 1fba9d9
Show file tree
Hide file tree
Showing 7 changed files with 262 additions and 46 deletions.
10 changes: 9 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#[macro_use]
extern crate static_assertions;

use self::{error::Result, http_client::HttpClient};
use self::{error::Result, http_client::HttpClient, sql::ser};
use ::serde::Serialize;
use std::{collections::HashMap, fmt::Display, sync::Arc};

pub use self::{compression::Compression, row::Row};
Expand Down Expand Up @@ -160,6 +161,13 @@ impl Client {
self
}

/// Specify server side parameter for all this client's queries.
pub fn with_param(self, name: &str, value: impl Serialize) -> Result<Self, String> {
let mut param = String::from("");
ser::write_param(&mut param, &value)?;
Ok(self.with_option(format!("param_{name}"), param))
}

/// Used to specify a header that will be passed to all queries.
///
/// # Example
Expand Down
17 changes: 15 additions & 2 deletions src/query.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use hyper::{header::CONTENT_LENGTH, Method, Request};
use serde::Deserialize;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use url::Url;

Expand All @@ -9,7 +9,7 @@ use crate::{
request_body::RequestBody,
response::Response,
row::Row,
sql::{Bind, SqlBuilder},
sql::{ser, Bind, SqlBuilder},
Client,
};

Expand Down Expand Up @@ -196,4 +196,17 @@ impl Query {
self.client.add_option(name, value);
self
}

/// Specify server side parameter for query.
///
/// In queries you can reference params as {name: type} e.g. {val: Int32}.
pub fn with_param(mut self, name: &str, value: impl Serialize) -> Self {
let mut param = String::from("");
if let Err(err) = ser::write_param(&mut param, &value) {
self.sql = SqlBuilder::Failed(format!("invalid param: {err}"));
self
} else {
self.with_option(format!("param_{name}"), param)
}
}
}
6 changes: 3 additions & 3 deletions src/sql/bind.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@ use super::{escape, ser};
#[sealed]
pub trait Bind {
#[doc(hidden)]
fn write(&self, dst: impl fmt::Write) -> Result<(), String>;
fn write(&self, dst: &mut impl fmt::Write) -> Result<(), String>;
}

#[sealed]
impl<S: Serialize> Bind for S {
#[inline]
fn write(&self, mut dst: impl fmt::Write) -> Result<(), String> {
fn write(&self, mut dst: &mut impl fmt::Write) -> Result<(), String> {
ser::write_arg(&mut dst, self)
}
}
Expand All @@ -26,7 +26,7 @@ pub struct Identifier<'a>(pub &'a str);
#[sealed]
impl<'a> Bind for Identifier<'a> {
#[inline]
fn write(&self, dst: impl fmt::Write) -> Result<(), String> {
fn write(&self, dst: &mut impl fmt::Write) -> Result<(), String> {
escape::identifier(self.0, dst).map_err(|err| err.to_string())
}
}
43 changes: 20 additions & 23 deletions src/sql/escape.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,32 @@
use std::fmt;

// Trust clickhouse-connect https://github.com/ClickHouse/clickhouse-connect/blob/5d85563410f3ec378cb199ec51d75e033211392c/clickhouse_connect/driver/binding.py#L15

// See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-string-literal
pub(crate) fn string(src: &str, dst: impl fmt::Write) -> fmt::Result {
escape(src, dst, '\'')
pub(crate) fn string(src: &str, dst: &mut impl fmt::Write) -> fmt::Result {
dst.write_char('\'')?;
escape(src, dst)?;
dst.write_char('\'')
}

// See https://clickhouse.tech/docs/en/sql-reference/syntax/#syntax-identifiers
pub(crate) fn identifier(src: &str, dst: impl fmt::Write) -> fmt::Result {
escape(src, dst, '`')
pub(crate) fn identifier(src: &str, dst: &mut impl fmt::Write) -> fmt::Result {
dst.write_char('`')?;
escape(src, dst)?;
dst.write_char('`')
}

fn escape(src: &str, mut dst: impl fmt::Write, ch: char) -> fmt::Result {
dst.write_char(ch)?;

// TODO: escape newlines?
for (idx, part) in src.split(ch).enumerate() {
if idx > 0 {
dst.write_char('\\')?;
dst.write_char(ch)?;
}

for (idx, part) in part.split('\\').enumerate() {
if idx > 0 {
dst.write_str("\\\\")?;
}

dst.write_str(part)?;
}
pub(crate) fn escape(src: &str, dst: &mut impl fmt::Write) -> fmt::Result {
const REPLACE: &[char] = &['\\', '\'', '`', '\t', '\n'];
let mut rest = src;
while let Some(nextidx) = rest.find(REPLACE) {
let (before, after) = rest.split_at(nextidx);
rest = &after[1..];
dst.write_str(before)?;
dst.write_char('\\')?;
dst.write_str(&after[..1])?;
}

dst.write_char(ch)
dst.write_str(rest)
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier};

mod bind;
pub(crate) mod escape;
mod ser;
pub(crate) mod ser;

#[derive(Debug, Clone)]
pub(crate) enum SqlBuilder {
Expand Down
Loading

0 comments on commit 1fba9d9

Please sign in to comment.