Skip to content

Commit

Permalink
test: cover the queries and caches
Browse files Browse the repository at this point in the history
  • Loading branch information
jacek-prisma committed Dec 11, 2024
1 parent d40b7ed commit 3e64697
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 21 deletions.
128 changes: 116 additions & 12 deletions quaint/src/connector/postgres/native/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub trait QueryCache: From<CacheSettings> + Send + Sync {
/// The type of query that is returned by the cache.
type Query: PreparedQuery;

/// Retrieves a query from the cache or prepares and caches it if it's not present.
/// Retrieve a prepared query from the cache or prepare and cache one if it's not present.
async fn get_by_query(&self, client: &Client, sql: &str, types: &[Type]) -> Result<Self::Query, Error>;
}

Expand Down Expand Up @@ -109,12 +109,7 @@ impl QueryCache for LruTracingCache {
Some(query) => Ok(query),
None => {
let stmt = client.prepare_typed(sql, types).await?;
let query = Arc::new(TypedQuery {
sql: sql.into(),
param_types: stmt.params().to_vec(),
column_names: stmt.columns().iter().map(|c| c.name().to_owned()).collect(),
column_types: stmt.columns().iter().map(|c| c.type_().clone()).collect(),
});
let query = Arc::new(TypedQuery::from_statement(sql, &stmt));
self.cache.insert(sql_without_traceparent, types, query.clone()).await;
Ok(query)
}
Expand All @@ -136,16 +131,15 @@ pub struct CacheSettings {

/// Key uniquely representing an SQL statement in the prepared statements cache.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct QueryKey {
struct QueryKey {
/// Hash of a string with SQL query.
sql: u64,
/// Combined hash of types for all parameters from the query.
types_hash: u64,
}

impl QueryKey {
fn new(sql: &str, params: &[Type]) -> Self {
let st = RandomState::new();
fn new<S: BuildHasher>(st: &S, sql: &str, params: &[Type]) -> Self {
Self {
sql: st.hash_one(sql),
types_hash: st.hash_one(params),
Expand All @@ -156,12 +150,14 @@ impl QueryKey {
#[derive(Debug)]
struct InnerLruCache<V> {
cache: Mutex<LruCache<QueryKey, V>>,
state: RandomState,
}

impl<V> InnerLruCache<V> {
fn with_capacity(capacity: usize) -> Self {
Self {
cache: Mutex::new(LruCache::new(capacity)),
state: RandomState::new(),
}
}

Expand All @@ -173,7 +169,7 @@ impl<V> InnerLruCache<V> {
let capacity = cache.capacity();
let stored = cache.len();

let key = QueryKey::new(sql, types);
let key = QueryKey::new(&self.state, sql, types);
match cache.get_mut(&key) {
Some(value) => {
tracing::trace!(
Expand All @@ -197,6 +193,114 @@ impl<V> InnerLruCache<V> {
}

pub async fn insert(&self, sql: &str, types: &[Type], value: V) {
self.cache.lock().await.insert(QueryKey::new(sql, types), value);
let key = QueryKey::new(&self.state, sql, types);
self.cache.lock().await.insert(key, value);
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::future::Future;

pub(crate) use crate::connector::postgres::url::PostgresNativeUrl;
use crate::{
connector::{MakeTlsConnectorManager, PostgresFlavour},
tests::test_api::postgres::CONN_STR,
};
use url::Url;

#[tokio::test]
async fn noop_prepared_statement_cache_prepares_new_statements_every_time() {
run_with_client(|client| async move {
let cache = NoopPreparedStatementCache;
let sql = "SELECT $1";
let types = [Type::INT4];

let stmt1 = cache.get_by_query(&client, sql, &types).await.unwrap();
let stmt2 = cache.get_by_query(&client, sql, &types).await.unwrap();
assert_ne!(stmt1.name(), stmt2.name());
})
.await;
}

#[tokio::test]
async fn lru_prepared_statement_cache_reuses_statements_within_capacity() {
run_with_client(|client| async move {
let cache = LruPreparedStatementCache::with_capacity(1);
let sql = "SELECT $1";
let types = [Type::INT4];

let stmt1 = cache.get_by_query(&client, sql, &types).await.unwrap();
let stmt2 = cache.get_by_query(&client, sql, &types).await.unwrap();
assert_eq!(stmt1.name(), stmt2.name());

// replace our cached statement with a new one going over the capacity
cache.get_by_query(&client, sql, &[Type::INT8]).await.unwrap();

// the old statement should be evicted from the cache
let stmt3 = cache.get_by_query(&client, sql, &types).await.unwrap();
assert_ne!(stmt1.name(), stmt3.name());
})
.await;
}

#[tokio::test]
async fn tracing_cache_reuses_queries_within_capacity() {
run_with_client(|client| async move {
let cache = LruTracingCache::with_capacity(1);
let sql = "SELECT $1";
let types = [Type::INT4];

let stmt1 = cache.get_by_query(&client, sql, &types).await.unwrap();
let stmt2 = cache.get_by_query(&client, sql, &types).await.unwrap();
assert!(Arc::ptr_eq(&stmt1, &stmt2), "stmt1 and stmt2 should be the same Arc");

// replace our cached query with a new one going over the capacity
cache.get_by_query(&client, sql, &[Type::INT8]).await.unwrap();

// the old query should be evicted from the cache
let stmt3 = cache.get_by_query(&client, sql, &types).await.unwrap();
assert!(
!Arc::ptr_eq(&stmt1, &stmt3),
"stmt1 and stmt3 should not be the same Arc"
);
})
.await;
}

#[tokio::test]
async fn tracing_cache_reuses_queries_with_different_traceparent() {
run_with_client(|client| async move {
let cache = LruTracingCache::with_capacity(1);
let sql1 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01 */";
let sql2 = "SELECT $1 /* traceparent=00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-02 */";
let types = [Type::INT4];

let stmt1 = cache.get_by_query(&client, sql1, &types).await.unwrap();
let stmt2 = cache.get_by_query(&client, sql2, &types).await.unwrap();
assert!(Arc::ptr_eq(&stmt1, &stmt2), "stmt1 and stmt2 should be the same Arc");
})
.await;
}

async fn run_with_client<Func, Fut>(test: Func)
where
Func: FnOnce(Client) -> Fut,
Fut: Future<Output = ()>,
{
let url = Url::parse(&CONN_STR).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Postgres);

let tls_manager = MakeTlsConnectorManager::new(pg_url.clone());
let tls = tls_manager.get_connector().await.unwrap();

let (client, conn) = pg_url.to_config().connect(tls).await.unwrap();

let set = tokio::task::LocalSet::new();
set.spawn_local(conn);
set.run_until(test(client)).await
}
}
12 changes: 7 additions & 5 deletions quaint/src/connector/postgres/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,17 @@ pub struct PostgreSql<QueriesCache, StmtsCache> {
db_system_name: &'static str,
}

/// A Postgres client with the default caching strategy, which involves storing everything as
/// prepared statements in an LRU cache.
/// A [`PostgreSql`] interface with the default caching strategy, which involves storing all
/// queries as prepared statements in an LRU cache.
pub type PostgreSqlWithDefaultCache = PostgreSql<LruPreparedStatementCache, LruPreparedStatementCache>;

/// A Postgres client which executes all queries as prepared statements without caching.
/// A [`PostgreSql`] interface which executes all queries as prepared statements without caching
/// them.
pub type PostgreSqlWithNoCache = PostgreSql<NoopPreparedStatementCache, NoopPreparedStatementCache>;

/// A Postgres client with a caching strategy dedicated to query tracing, which involves storing
/// query type information in a dedicated LRU cache and not re-using any prepared statements.
/// A [`PostgreSql`] interface with the tracing caching strategy, which involves storing query
/// type information in a dedicated LRU cache for applicable queries and not re-using any prepared
/// statements.
pub type PostgreSqlWithTracingCache = PostgreSql<LruTracingCache, NoopPreparedStatementCache>;

#[derive(Debug)]
Expand Down
120 changes: 116 additions & 4 deletions quaint/src/connector/postgres/native/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,22 @@ impl PreparedQuery for Statement {
/// A query combined with the relevant type information about its parameters and columns.
#[derive(Debug)]
pub struct TypedQuery {
pub(super) sql: String,
pub(super) param_types: Vec<Type>,
pub(super) column_names: Vec<String>,
pub(super) column_types: Vec<Type>,
sql: String,
param_types: Vec<Type>,
column_names: Vec<String>,
column_types: Vec<Type>,
}

impl TypedQuery {
/// Create a new typed query from a SQL string and a statement.
pub fn from_statement(sql: impl Into<String>, statement: &Statement) -> Self {
Self {
sql: sql.into(),
param_types: statement.params().to_vec(),
column_names: statement.columns().iter().map(|c| c.name().to_owned()).collect(),
column_types: statement.columns().iter().map(|c| c.type_().clone()).collect(),
}
}
}

#[async_trait]
Expand Down Expand Up @@ -105,3 +117,103 @@ impl<A: PreparedQuery + Sync> PreparedQuery for Arc<A> {
self.as_ref().dispatch(client, args).await
}
}

#[cfg(test)]
mod tests {
use super::*;

use std::future::Future;

pub(crate) use crate::connector::postgres::url::PostgresNativeUrl;
use crate::{
connector::{MakeTlsConnectorManager, PostgresFlavour},
tests::test_api::postgres::CONN_STR,
};
use url::Url;

#[tokio::test]
async fn typed_query_matches_statement_and_dispatches() {
run_with_client(|client| async move {
let query = "SELECT $1";
let stmt = client.prepare_typed(query, &[Type::INT4]).await.unwrap();
let typed = TypedQuery::from_statement(query, &stmt);

assert_eq!(typed.param_types().cloned().collect::<Vec<_>>(), stmt.params());
assert_eq!(
typed.column_names().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.name()).collect::<Vec<_>>()
);
assert_eq!(
typed.column_types().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.type_()).collect::<Vec<_>>()
);

let result = typed.dispatch(&client, &[&1i32]).await;
assert!(result.is_ok(), "{:?}", result.err());
})
.await;
}

#[tokio::test]
async fn statement_trait_methods_match_statement_and_dispatch() {
run_with_client(|client| async move {
let query = "SELECT $1";
let stmt = client.prepare_typed(query, &[Type::INT4]).await.unwrap();

assert_eq!(stmt.param_types().cloned().collect::<Vec<_>>(), stmt.params());
assert_eq!(
stmt.column_names().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.name()).collect::<Vec<_>>()
);
assert_eq!(
stmt.column_types().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.type_()).collect::<Vec<_>>()
);

let result = stmt.dispatch(&client, &[&1i32]).await;
assert!(result.is_ok(), "{:?}", result.err());
})
.await;
}

#[tokio::test]
async fn arc_trait_methods_match_statement_and_dispatch() {
run_with_client(|client| async move {
let query = "SELECT $1";
let stmt = Arc::new(client.prepare_typed(query, &[Type::INT4]).await.unwrap());

assert_eq!(stmt.param_types().cloned().collect::<Vec<_>>(), stmt.params());
assert_eq!(
stmt.column_names().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.name()).collect::<Vec<_>>()
);
assert_eq!(
stmt.column_types().collect::<Vec<_>>(),
stmt.columns().iter().map(|c| c.type_()).collect::<Vec<_>>()
);

let result = stmt.dispatch(&client, &[&1i32]).await;
assert!(result.is_ok(), "{:?}", result.err());
})
.await;
}

async fn run_with_client<Func, Fut>(test: Func)
where
Func: FnOnce(Client) -> Fut,
Fut: Future<Output = ()>,
{
let url = Url::parse(&CONN_STR).unwrap();
let mut pg_url = PostgresNativeUrl::new(url).unwrap();
pg_url.set_flavour(PostgresFlavour::Postgres);

let tls_manager = MakeTlsConnectorManager::new(pg_url.clone());
let tls = tls_manager.get_connector().await.unwrap();

let (client, conn) = pg_url.to_config().connect(tls).await.unwrap();

let set = tokio::task::LocalSet::new();
set.spawn_local(conn);
set.run_until(test(client)).await
}
}

0 comments on commit 3e64697

Please sign in to comment.