Skip to content

Commit

Permalink
sqlite-input-improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jeromegn committed Oct 23, 2023
1 parent 94bd76a commit 36d8254
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 44 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] }
rustls-pemfile = "1.0.2"
seahash = "4.1.0"
serde = "1.0.159"
serde_json = "1.0.95"
serde_json = { version = "1.0.95", features = ["raw_value"] }
serde_with = "2.3.2"
smallvec = { version = "1.11.0", features = ["serde", "write", "union"] }
speedy = { version = "0.8.7", features = ["uuid", "smallvec"], package = "corro-speedy" }
Expand Down
50 changes: 34 additions & 16 deletions crates/corro-agent/src/api/public/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,25 @@ where

#[tracing::instrument(skip_all, err)]
fn execute_statement(tx: &Transaction, stmt: &Statement) -> rusqlite::Result<usize> {
let mut prepped = match &stmt {
Statement::Simple(q) => tx.prepare(q),
Statement::WithParams(q, _) => tx.prepare(q),
Statement::WithNamedParams(q, _) => tx.prepare(q),
}?;
let mut prepped = tx.prepare(stmt.query())?;

match stmt {
Statement::Simple(_) => prepped.execute([]),
Statement::WithParams(_, params) => prepped.execute(params_from_iter(params)),
Statement::WithNamedParams(_, params) => prepped.execute(
Statement::Simple(_)
| Statement::Verbose {
params: None,
named_params: None,
..
} => prepped.execute([]),
Statement::WithParams(_, params)
| Statement::Verbose {
params: Some(params),
..
} => prepped.execute(params_from_iter(params)),
Statement::WithNamedParams(_, params)
| Statement::Verbose {
named_params: Some(params),
..
} => prepped.execute(
params
.iter()
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))
Expand Down Expand Up @@ -429,11 +438,7 @@ async fn build_query_rows_response(
}
};

let prepped_res = block_in_place(|| match &stmt {
Statement::Simple(q) => conn.prepare(q),
Statement::WithParams(q, _) => conn.prepare(q),
Statement::WithNamedParams(q, _) => conn.prepare(q),
});
let prepped_res = block_in_place(|| conn.prepare(stmt.query()));

let mut prepped = match prepped_res {
Ok(prepped) => prepped,
Expand Down Expand Up @@ -476,9 +481,22 @@ async fn build_query_rows_response(
let start = Instant::now();

let query = match stmt {
Statement::Simple(_) => prepped.query(()),
Statement::WithParams(_, params) => prepped.query(params_from_iter(params)),
Statement::WithNamedParams(_, params) => prepped.query(
Statement::Simple(_)
| Statement::Verbose {
params: None,
named_params: None,
..
} => prepped.query(()),
Statement::WithParams(_, params)
| Statement::Verbose {
params: Some(params),
..
} => prepped.query(params_from_iter(params)),
Statement::WithNamedParams(_, params)
| Statement::Verbose {
named_params: Some(params),
..
} => prepped.query(
params
.iter()
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))
Expand Down
25 changes: 20 additions & 5 deletions crates/corro-agent/src/api/public/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,31 @@ pub async fn process_sub_channel(

fn expanded_statement(conn: &Connection, stmt: &Statement) -> rusqlite::Result<Option<String>> {
Ok(match stmt {
Statement::Simple(q) => conn.prepare(q)?.expanded_sql(),
Statement::WithParams(q, params) => {
let mut prepped = conn.prepare(q)?;
Statement::Simple(query)
| Statement::Verbose {
query,
params: None,
named_params: None,
} => conn.prepare(query)?.expanded_sql(),
Statement::WithParams(query, params)
| Statement::Verbose {
query,
params: Some(params),
..
} => {
let mut prepped = conn.prepare(query)?;
for (i, param) in params.iter().enumerate() {
prepped.raw_bind_parameter(i + 1, param)?;
}
prepped.expanded_sql()
}
Statement::WithNamedParams(q, params) => {
let mut prepped = conn.prepare(q)?;
Statement::WithNamedParams(query, params)
| Statement::Verbose {
query,
named_params: Some(params),
..
} => {
let mut prepped = conn.prepare(query)?;
for (k, v) in params.iter() {
let idx = match prepped.parameter_index(k)? {
Some(idx) => idx,
Expand Down
3 changes: 2 additions & 1 deletion crates/corro-api-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ compact_str = { workspace = true }
hex = { workspace = true }
rusqlite = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
smallvec = { workspace = true }
speedy = { workspace = true }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio = { workspace = true }
127 changes: 125 additions & 2 deletions crates/corro-api-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rusqlite::{
Row, ToSql,
};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use smallvec::{SmallVec, ToSmallVec};
use speedy::{Context, Readable, Reader, Writable, Writer};
use sqlite::ChangeType;
Expand Down Expand Up @@ -120,9 +121,25 @@ impl ToSql for ChangeId {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Statement {
Verbose {
query: String,
params: Option<Vec<SqliteParam>>,
named_params: Option<HashMap<String, SqliteParam>>,
},
Simple(String),
WithParams(String, Vec<SqliteValue>),
WithNamedParams(String, HashMap<String, SqliteValue>),
WithParams(String, Vec<SqliteParam>),
WithNamedParams(String, HashMap<String, SqliteParam>),
}

impl Statement {
pub fn query(&self) -> &str {
match self {
Statement::Verbose { query, .. }
| Statement::Simple(query)
| Statement::WithParams(query, _)
| Statement::WithNamedParams(query, _) => query,
}
}
}

impl From<&str> for Statement {
Expand Down Expand Up @@ -292,6 +309,76 @@ impl FromSql for ColumnType {
}
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqliteParam {
#[default]
Null,
Bool(bool),
Integer(i64),
Real(f64),
Text(CompactString),
Blob(SmallVec<[u8; 512]>),
Json(Box<RawValue>),
}

impl From<&str> for SqliteParam {
fn from(value: &str) -> Self {
Self::Text(value.into())
}
}

impl From<Vec<u8>> for SqliteParam {
fn from(value: Vec<u8>) -> Self {
Self::Blob(value.into())
}
}

impl From<String> for SqliteParam {
fn from(value: String) -> Self {
Self::Text(value.into())
}
}

impl From<u16> for SqliteParam {
fn from(value: u16) -> Self {
Self::Integer(value as i64)
}
}

impl From<i64> for SqliteParam {
fn from(value: i64) -> Self {
Self::Integer(value)
}
}

impl ToSql for SqliteParam {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
Ok(match self {
SqliteParam::Null => ToSqlOutput::Owned(Value::Null),
SqliteParam::Bool(v) => ToSqlOutput::Owned(Value::Integer(*v as i64)),
SqliteParam::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
SqliteParam::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
SqliteParam::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
SqliteParam::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
SqliteParam::Json(map) => ToSqlOutput::Borrowed(ValueRef::Text(map.get().as_bytes())),
})
}
}

impl<'a> ToSql for SqliteValueRef<'a> {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'a>> {
Ok(match self {
SqliteValueRef::Null => ToSqlOutput::Owned(Value::Null),
SqliteValueRef::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
SqliteValueRef::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
SqliteValueRef::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
SqliteValueRef::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
})
}
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Hash)]
#[serde(untagged)]
Expand Down Expand Up @@ -655,3 +742,39 @@ impl ToSql for ColumnName {
self.0.as_str().to_sql()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_statement_serialization() {
let s = serde_json::to_string(&vec![Statement::WithParams(
"select 1
from table
where column = ?"
.into(),
vec!["my-value".into()],
)])
.unwrap();
println!("{s}");

let stmts: Vec<Statement> = serde_json::from_str(&s).unwrap();
println!("stmts: {stmts:?}");

let json = r#"[["some statement",[1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]]]"#;

let value: serde_json::Value = serde_json::from_str(json).unwrap();
println!("value: {value:#?}");

let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
println!("stmts: {stmts:?}");

let json = r#"[{"query": "some statement", "params": [1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]}]"#;
let value: serde_json::Value = serde_json::from_str(json).unwrap();
println!("value: {value:#?}");

let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
println!("stmts: {stmts:?}");
}
}
28 changes: 12 additions & 16 deletions crates/corro-tpl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use compact_str::ToCompactString;
use corro_client::sub::SubscriptionStream;
use corro_client::CorrosionApiClient;
use corro_types::api::QueryEvent;
use corro_types::api::SqliteParam;
use corro_types::api::Statement;
use corro_types::change::SqliteValue;
use futures::StreamExt;
Expand Down Expand Up @@ -536,33 +537,28 @@ impl Engine {
}
});

fn dyn_to_sql(v: Dynamic) -> Result<SqliteValue, Box<EvalAltResult>> {
fn dyn_to_sql(v: Dynamic) -> Result<SqliteParam, Box<EvalAltResult>> {
Ok(match v.type_name() {
"()" => SqliteValue::Null,
"i64" => SqliteValue::Integer(
"()" => SqliteParam::Null,
"i64" => SqliteParam::Integer(
v.as_int()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to i64")))?,
),
"f64" => SqliteValue::Real(corro_types::api::Real(
"f64" => SqliteParam::Real(
v.as_float()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to f64")))?,
)),
"bool" => {
if v.as_bool()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?
{
SqliteValue::Integer(1)
} else {
SqliteValue::Integer(0)
}
}
"blob" => SqliteValue::Blob(
),
"bool" => SqliteParam::Bool(
v.as_bool()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?,
),
"blob" => SqliteParam::Blob(
v.into_blob()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to blob")))?
.into(),
),
// convert everything else into a string, including a string
_ => SqliteValue::Text(v.to_compact_string()),
_ => SqliteParam::Text(v.to_compact_string()),
})
}

Expand Down
6 changes: 3 additions & 3 deletions crates/corrosion/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use command::{
tls::{generate_ca, generate_client_cert, generate_server_cert},
tpl::TemplateFlags,
};
use corro_api_types::SqliteValue;
use corro_api_types::{SqliteParam, SqliteValue};
use corro_client::CorrosionApiClient;
use corro_types::{
api::{ExecResult, QueryEvent, Statement},
Expand Down Expand Up @@ -301,7 +301,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
} else {
Statement::WithParams(
query.clone(),
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
)
};

Expand Down Expand Up @@ -359,7 +359,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
} else {
Statement::WithParams(
query.clone(),
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
)
};

Expand Down

0 comments on commit 36d8254

Please sign in to comment.