Skip to content

Commit

Permalink
fix: make sure to return the caller-provided sql with TypedQuery
Browse files Browse the repository at this point in the history
  • Loading branch information
jacek-prisma committed Dec 11, 2024
1 parent 153c109 commit 17c1994
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 99 deletions.
78 changes: 44 additions & 34 deletions quaint/src/connector/postgres/native/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ use tokio_postgres::{Client, Error, Statement};

use crate::connector::metrics::strip_query_traceparent;

use super::query::{PreparedQuery, TypedQuery};
use super::query::{PreparedQuery, QueryMetadata, TypedQuery};

/// Types that can be used as a cache for prepared queries and statements.
#[async_trait]
pub trait QueryCache: From<CacheSettings> + Send + Sync {
/// The type that is returned when a prepared query is requested from the cache.
type Query: PreparedQuery;
type Query<'a>: PreparedQuery;

/// Retrieve a prepared query.
async fn get_query(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Self::Query, Error>;
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Self::Query<'a>, Error>;

/// Retrieve a prepared statement.
///
Expand All @@ -36,10 +36,10 @@ pub struct NoOpCache;

#[async_trait]
impl QueryCache for NoOpCache {
type Query = Statement;
type Query<'a> = Statement;

#[inline]
async fn get_query(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Statement, Error> {
self.get_statement(client, sql, types).await
}

Expand All @@ -55,7 +55,7 @@ impl From<CacheSettings> for NoOpCache {
}
}

/// An LRU cache that creates a prepared statement for every newly requested query.
/// An LRU cache that creates a prepared statement for every query that is not in the cache.
#[derive(Debug)]
pub struct PreparedStatementLruCache {
cache: InnerLruCache<Statement>,
Expand All @@ -71,10 +71,10 @@ impl PreparedStatementLruCache {

#[async_trait]
impl QueryCache for PreparedStatementLruCache {
type Query = Statement;
type Query<'a> = Statement;

#[inline]
async fn get_query(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<Statement, Error> {
self.get_statement(client, sql, types).await
}

Expand All @@ -96,15 +96,15 @@ impl From<CacheSettings> for PreparedStatementLruCache {
}
}

/// An LRU cache that creates and stores type information relevant to queries as instances of
/// [`TypedQuery`]. Queries are identified by their content with tracing information removed
/// (which makes it possible to cache them at all). The caching behavior is implemented in
/// [`get_query`](Self::get_query), while statements returned by
/// An LRU cache that creates and stores query type information rather than prepared statements.
/// Queries are identified by their content with tracing information removed (which makes it
/// possible to cache them at all) and returned as instances of [`TypedQuery`]. The caching
/// behavior is implemented in [`get_query`](Self::get_query), while statements returned from
/// [`get_statement`](Self::get_statement) are always freshly prepared, because statements cannot
/// be re-used when tracing information is present.
#[derive(Debug)]
pub struct TracingLruCache {
cache: InnerLruCache<Arc<TypedQuery>>,
cache: InnerLruCache<Arc<QueryMetadata>>,
}

impl TracingLruCache {
Expand All @@ -117,20 +117,21 @@ impl TracingLruCache {

#[async_trait]
impl QueryCache for TracingLruCache {
type Query = Arc<TypedQuery>;
type Query<'a> = TypedQuery<'a>;

async fn get_query(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Arc<TypedQuery>, Error> {
async fn get_query<'a>(&self, client: &Client, sql: &'a str, types: &[Type]) -> Result<TypedQuery<'a>, Error> {
let sql_without_traceparent = strip_query_traceparent(sql);

match self.cache.get(sql_without_traceparent, types).await {
Some(query) => Ok(query),
let metadata = match self.cache.get(sql_without_traceparent, types).await {
Some(metadata) => metadata,
None => {
let stmt = client.prepare_typed(sql, types).await?;
let query = Arc::new(TypedQuery::from_statement(sql, &stmt));
self.cache.insert(sql_without_traceparent, types, query.clone()).await;
Ok(query)
let stmt = client.prepare_typed(sql_without_traceparent, types).await?;
let metdata = Arc::new(QueryMetadata::from(&stmt));
self.cache.insert(sql_without_traceparent, types, metdata.clone()).await;
metdata
}
}
};
Ok(TypedQuery::from_sql_and_metadata(sql, metadata))
}

async fn get_statement(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Statement, Error> {
Expand Down Expand Up @@ -310,18 +311,21 @@ mod tests {
let sql = "SELECT $1";
let types = [Type::INT4];

let stmt1 = cache.get_query(&client, sql, &types).await.unwrap();
let stmt2 = cache.get_query(&client, sql, &types).await.unwrap();
assert!(Arc::ptr_eq(&stmt1, &stmt2), "stmt1 and stmt2 should be the same Arc");
let q1 = cache.get_query(&client, sql, &types).await.unwrap();
let q2 = cache.get_query(&client, sql, &types).await.unwrap();
assert!(
std::ptr::eq(q1.metadata(), q2.metadata()),
"stmt1 and stmt2 should re-use the same metadata"
);

// replace our cached query with a new one going over the capacity
cache.get_query(&client, sql, &[Type::INT8]).await.unwrap();

// the old query should be evicted from the cache
let stmt3 = cache.get_query(&client, sql, &types).await.unwrap();
let q3 = cache.get_query(&client, sql, &types).await.unwrap();
assert!(
!Arc::ptr_eq(&stmt1, &stmt3),
"stmt1 and stmt3 should not be the same Arc"
!std::ptr::eq(q1.metadata(), q3.metadata()),
"stmt1 and stmt3 should not re-use the same metadata"
);
})
.await;
Expand All @@ -335,9 +339,15 @@ mod tests {
let sql2 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-02 */";
let types = [Type::INT4];

let stmt1 = cache.get_query(&client, sql1, &types).await.unwrap();
let stmt2 = cache.get_query(&client, sql2, &types).await.unwrap();
assert!(Arc::ptr_eq(&stmt1, &stmt2), "stmt1 and stmt2 should be the same Arc");
let q1 = cache.get_query(&client, sql1, &types).await.unwrap();
assert_eq!(q1.query(), sql1);
let q2 = cache.get_query(&client, sql2, &types).await.unwrap();
assert_eq!(q2.query(), sql2);

assert!(
std::ptr::eq(q1.metadata(), q2.metadata()),
"stmt1 and stmt2 should re-use the same metadata"
);
})
.await;
}
Expand All @@ -349,9 +359,9 @@ mod tests {
let sql = "SELECT $1";
let types = [Type::INT4];

let stmt1 = cache.get_statement(&client, sql, &types).await.unwrap();
let stmt2 = cache.get_statement(&client, sql, &types).await.unwrap();
assert_ne!(stmt1.name(), stmt2.name());
let q1 = cache.get_statement(&client, sql, &types).await.unwrap();
let q2 = cache.get_statement(&client, sql, &types).await.unwrap();
assert_ne!(q1.name(), q2.name());
})
.await;
}
Expand Down
102 changes: 37 additions & 65 deletions quaint/src/connector/postgres/native/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,76 +45,70 @@ impl PreparedQuery for Statement {

/// A query combined with the relevant type information about its parameters and columns.
#[derive(Debug)]
pub struct TypedQuery {
sql: String,
param_types: Vec<Type>,
column_names: Vec<String>,
column_types: Vec<Type>,
pub struct TypedQuery<'a> {
sql: &'a str,
metadata: Arc<QueryMetadata>,
}

impl TypedQuery {
impl<'a> TypedQuery<'a> {
/// Create a new typed query from a SQL string and a statement.
pub fn from_statement(sql: impl Into<String>, statement: &Statement) -> Self {
pub fn from_sql_and_metadata(sql: &'a str, metadata: impl Into<Arc<QueryMetadata>>) -> Self {
Self {
sql: sql.into(),
param_types: statement.params().to_vec(),
column_names: statement.columns().iter().map(|c| c.name().to_owned()).collect(),
column_types: statement.columns().iter().map(|c| c.type_().clone()).collect(),
sql,
metadata: metadata.into(),
}
}
}

#[async_trait]
impl PreparedQuery for TypedQuery {
fn param_types(&self) -> impl ExactSizeIterator<Item = &Type> + '_ {
self.param_types.iter()
}

fn column_names(&self) -> impl ExactSizeIterator<Item = &str> + '_ {
self.column_names.iter().map(|s| s.as_str())
}

fn column_types(&self) -> impl ExactSizeIterator<Item = &Type> + '_ {
self.column_types.iter()
/// Get the SQL string of the query.
pub fn query(&self) -> &'a str {
self.sql
}

async fn dispatch<Args>(&self, client: &Client, args: Args) -> Result<RowStream, Error>
where
Args: IntoIterator + Send,
Args::Item: BorrowToSql,
Args::IntoIter: ExactSizeIterator + Send,
{
client
.query_typed_raw(&self.sql, args.into_iter().zip(self.param_types.iter().cloned()))
.await
/// Get the metadata associated with the query.
pub fn metadata(&self) -> &QueryMetadata {
&self.metadata
}
}

#[async_trait]
impl<A: PreparedQuery + Sync> PreparedQuery for Arc<A> {
#[inline]
impl<'a> PreparedQuery for TypedQuery<'a> {
fn param_types(&self) -> impl ExactSizeIterator<Item = &Type> + '_ {
self.as_ref().param_types()
self.metadata.param_types.iter()
}

#[inline]
fn column_names(&self) -> impl ExactSizeIterator<Item = &str> + '_ {
self.as_ref().column_names()
self.metadata.column_names.iter().map(|s| s.as_str())
}

#[inline]
fn column_types(&self) -> impl ExactSizeIterator<Item = &Type> + '_ {
self.as_ref().column_types()
self.metadata.column_types.iter()
}

#[inline]
async fn dispatch<Args>(&self, client: &Client, args: Args) -> Result<RowStream, Error>
where
Args: IntoIterator + Send,
Args::Item: BorrowToSql,
Args::IntoIter: ExactSizeIterator + Send,
{
self.as_ref().dispatch(client, args).await
let typed_args = args.into_iter().zip(self.metadata.param_types.iter().cloned());
client.query_typed_raw(self.sql, typed_args).await
}
}

#[derive(Debug)]
pub struct QueryMetadata {
param_types: Vec<Type>,
column_names: Vec<String>,
column_types: Vec<Type>,
}

impl From<&Statement> for QueryMetadata {
fn from(statement: &Statement) -> Self {
Self {
param_types: statement.params().to_vec(),
column_names: statement.columns().iter().map(|c| c.name().to_owned()).collect(),
column_types: statement.columns().iter().map(|c| c.type_().clone()).collect(),
}
}
}

Expand All @@ -136,7 +130,7 @@ mod tests {
run_with_client(|client| async move {
let query = "SELECT $1";
let stmt = client.prepare_typed(query, &[Type::INT4]).await.unwrap();
let typed = TypedQuery::from_statement(query, &stmt);
let typed = TypedQuery::from_sql_and_metadata(query, QueryMetadata::from(&stmt));

assert_eq!(typed.param_types().cloned().collect::<Vec<_>>(), stmt.params());
assert_eq!(
Expand Down Expand Up @@ -176,28 +170,6 @@ mod tests {
.await;
}

#[tokio::test]
async fn arc_trait_methods_match_statement_and_dispatch() {
run_with_client(|client| async move {
let query = "SELECT $1";
let stmt = Arc::new(client.prepare_typed(query, &[Type::INT4]).await.unwrap());

assert_eq!(stmt.param_types().cloned().collect::<Vec<_>>(), stmt.params());
assert_eq!(
stmt.column_names().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.name()).collect::<Vec<_>>()
);
assert_eq!(
stmt.column_types().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.type_()).collect::<Vec<_>>()
);

let result = stmt.dispatch(&client, &[&1i32]).await;
assert!(result.is_ok(), "{:?}", result.err());
})
.await;
}

async fn run_with_client<Func, Fut>(test: Func)
where
Func: FnOnce(Client) -> Fut,
Expand Down

0 comments on commit 17c1994

Please sign in to comment.