From 8891fb79305675d89745cb5d04f38bbc72401d2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Philip=20Dub=C3=A9?= Date: Tue, 24 Sep 2024 13:31:57 +0000 Subject: [PATCH] Server side parameters via with_param Fixes #142 --- src/lib.rs | 13 ++++++++++++- src/query.rs | 12 +++++++++++- tests/it/query.rs | 25 +++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 6df4a50..8d5c16f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,7 +5,8 @@ #[macro_use] extern crate static_assertions; -use self::{error::Result, http_client::HttpClient}; +use self::{error::Result, http_client::HttpClient, sql::Bind}; +use ::serde::Serialize; use std::{collections::HashMap, fmt::Display, sync::Arc}; pub use self::{compression::Compression, row::Row}; @@ -160,6 +161,16 @@ impl Client { self } + pub fn with_param( + self, + name: impl Into, + value: impl Bind + Serialize, + ) -> Result { + let mut param = String::from("param_"); + Bind::write(&value, &mut param)?; + Ok(self.with_option(name, param)) + } + /// Used to specify a header that will be passed to all queries. /// /// # Example diff --git a/src/query.rs b/src/query.rs index 926fdb5..b9f998f 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,5 +1,5 @@ use hyper::{header::CONTENT_LENGTH, Method, Request}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use std::fmt::Display; use url::Url; @@ -195,6 +195,16 @@ impl Query { self.client.add_option(name, value); self } + + pub fn with_param( + self, + name: impl Into, + value: impl Bind + Serialize, + ) -> Result { + let mut param = String::from("param_"); + Bind::write(&value, &mut param)?; + Ok(self.with_option(name, param)) + } } /// A cursor that emits rows. diff --git a/tests/it/query.rs b/tests/it/query.rs index 80e0158..b747495 100644 --- a/tests/it/query.rs +++ b/tests/it/query.rs @@ -85,6 +85,31 @@ async fn fetch_one_and_optional() { assert_eq!(got_string, "bar"); } +#[tokio::test] +async fn server_side_param() { + let client = prepare_database!() + .with_param("val1", 42) + .expect("failed to bind 42"); + + let result = client + .query("SELECT plus({val1: Int32}, {val2: String}) AS result") + .with_param("val2", 144) + .expect("failed to bind 144") + .fetch_one::() + .await + .expect("failed to fetch u64"); + assert_eq!(result, 186); + + let result = client + .query("SELECT {val1: String} AS result") + .with_param("param_val1", "string") + .expect("failed to bind \"string\"") + .fetch_one::() + .await + .expect("failed to fetch string"); + assert_eq!(result, "string"); +} + // See #19. #[tokio::test] async fn long_query() {