Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve SQLite-typed input for API #82

Merged
merged 2 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ rustls = { version = "0.21.0", features = ["dangerous_configuration", "quic"] }
rustls-pemfile = "1.0.2"
seahash = "4.1.0"
serde = "1.0.159"
serde_json = "1.0.95"
serde_json = { version = "1.0.95", features = ["raw_value"] }
serde_with = "2.3.2"
smallvec = { version = "1.11.0", features = ["serde", "write", "union"] }
speedy = { version = "0.8.7", features = ["uuid", "smallvec"], package = "corro-speedy" }
Expand Down
2 changes: 1 addition & 1 deletion crates/corro-agent/src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2736,7 +2736,7 @@ pub mod tests {
use serde_json::json;
use spawn::wait_for_all_pending_handles;
use tokio::time::{sleep, timeout, MissedTickBehavior};
use tracing::{info, info_span};
use tracing::info_span;
use tripwire::Tripwire;

use super::*;
Expand Down
50 changes: 34 additions & 16 deletions crates/corro-agent/src/api/public/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,16 +308,25 @@ where

#[tracing::instrument(skip_all, err)]
fn execute_statement(tx: &Transaction, stmt: &Statement) -> rusqlite::Result<usize> {
let mut prepped = match &stmt {
Statement::Simple(q) => tx.prepare(q),
Statement::WithParams(q, _) => tx.prepare(q),
Statement::WithNamedParams(q, _) => tx.prepare(q),
}?;
let mut prepped = tx.prepare(stmt.query())?;

match stmt {
Statement::Simple(_) => prepped.execute([]),
Statement::WithParams(_, params) => prepped.execute(params_from_iter(params)),
Statement::WithNamedParams(_, params) => prepped.execute(
Statement::Simple(_)
| Statement::Verbose {
params: None,
named_params: None,
..
} => prepped.execute([]),
Statement::WithParams(_, params)
| Statement::Verbose {
params: Some(params),
..
} => prepped.execute(params_from_iter(params)),
Statement::WithNamedParams(_, params)
| Statement::Verbose {
named_params: Some(params),
..
} => prepped.execute(
params
.iter()
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))
Expand Down Expand Up @@ -429,11 +438,7 @@ async fn build_query_rows_response(
}
};

let prepped_res = block_in_place(|| match &stmt {
Statement::Simple(q) => conn.prepare(q),
Statement::WithParams(q, _) => conn.prepare(q),
Statement::WithNamedParams(q, _) => conn.prepare(q),
});
let prepped_res = block_in_place(|| conn.prepare(stmt.query()));

let mut prepped = match prepped_res {
Ok(prepped) => prepped,
Expand Down Expand Up @@ -476,9 +481,22 @@ async fn build_query_rows_response(
let start = Instant::now();

let query = match stmt {
Statement::Simple(_) => prepped.query(()),
Statement::WithParams(_, params) => prepped.query(params_from_iter(params)),
Statement::WithNamedParams(_, params) => prepped.query(
Statement::Simple(_)
| Statement::Verbose {
params: None,
named_params: None,
..
} => prepped.query(()),
Statement::WithParams(_, params)
| Statement::Verbose {
params: Some(params),
..
} => prepped.query(params_from_iter(params)),
Statement::WithNamedParams(_, params)
| Statement::Verbose {
named_params: Some(params),
..
} => prepped.query(
params
.iter()
.map(|(k, v)| (k.as_str(), v as &dyn ToSql))
Expand Down
25 changes: 20 additions & 5 deletions crates/corro-agent/src/api/public/pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -202,16 +202,31 @@ pub async fn process_sub_channel(

fn expanded_statement(conn: &Connection, stmt: &Statement) -> rusqlite::Result<Option<String>> {
Ok(match stmt {
Statement::Simple(q) => conn.prepare(q)?.expanded_sql(),
Statement::WithParams(q, params) => {
let mut prepped = conn.prepare(q)?;
Statement::Simple(query)
| Statement::Verbose {
query,
params: None,
named_params: None,
} => conn.prepare(query)?.expanded_sql(),
Statement::WithParams(query, params)
| Statement::Verbose {
query,
params: Some(params),
..
} => {
let mut prepped = conn.prepare(query)?;
for (i, param) in params.iter().enumerate() {
prepped.raw_bind_parameter(i + 1, param)?;
}
prepped.expanded_sql()
}
Statement::WithNamedParams(q, params) => {
let mut prepped = conn.prepare(q)?;
Statement::WithNamedParams(query, params)
| Statement::Verbose {
query,
named_params: Some(params),
..
} => {
let mut prepped = conn.prepare(query)?;
for (k, v) in params.iter() {
let idx = match prepped.parameter_index(k)? {
Some(idx) => idx,
Expand Down
3 changes: 2 additions & 1 deletion crates/corro-api-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ compact_str = { workspace = true }
hex = { workspace = true }
rusqlite = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
smallvec = { workspace = true }
speedy = { workspace = true }
strum = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio = { workspace = true }
127 changes: 125 additions & 2 deletions crates/corro-api-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use rusqlite::{
Row, ToSql,
};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue;
use smallvec::{SmallVec, ToSmallVec};
use speedy::{Context, Readable, Reader, Writable, Writer};
use sqlite::ChangeType;
Expand Down Expand Up @@ -120,9 +121,25 @@ impl ToSql for ChangeId {
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Statement {
Verbose {
query: String,
params: Option<Vec<SqliteParam>>,
named_params: Option<HashMap<String, SqliteParam>>,
},
Simple(String),
WithParams(String, Vec<SqliteValue>),
WithNamedParams(String, HashMap<String, SqliteValue>),
WithParams(String, Vec<SqliteParam>),
WithNamedParams(String, HashMap<String, SqliteParam>),
}

impl Statement {
pub fn query(&self) -> &str {
match self {
Statement::Verbose { query, .. }
| Statement::Simple(query)
| Statement::WithParams(query, _)
| Statement::WithNamedParams(query, _) => query,
}
}
}

impl From<&str> for Statement {
Expand Down Expand Up @@ -292,6 +309,76 @@ impl FromSql for ColumnType {
}
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SqliteParam {
#[default]
Null,
Bool(bool),
Integer(i64),
Real(f64),
Text(CompactString),
Blob(SmallVec<[u8; 512]>),
Json(Box<RawValue>),
}

impl From<&str> for SqliteParam {
fn from(value: &str) -> Self {
Self::Text(value.into())
}
}

impl From<Vec<u8>> for SqliteParam {
fn from(value: Vec<u8>) -> Self {
Self::Blob(value.into())
}
}

impl From<String> for SqliteParam {
fn from(value: String) -> Self {
Self::Text(value.into())
}
}

impl From<u16> for SqliteParam {
fn from(value: u16) -> Self {
Self::Integer(value as i64)
}
}

impl From<i64> for SqliteParam {
fn from(value: i64) -> Self {
Self::Integer(value)
}
}

impl ToSql for SqliteParam {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'_>> {
Ok(match self {
SqliteParam::Null => ToSqlOutput::Owned(Value::Null),
SqliteParam::Bool(v) => ToSqlOutput::Owned(Value::Integer(*v as i64)),
SqliteParam::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
SqliteParam::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
SqliteParam::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
SqliteParam::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
SqliteParam::Json(map) => ToSqlOutput::Borrowed(ValueRef::Text(map.get().as_bytes())),
})
}
}

impl<'a> ToSql for SqliteValueRef<'a> {
fn to_sql(&self) -> rusqlite::Result<ToSqlOutput<'a>> {
Ok(match self {
SqliteValueRef::Null => ToSqlOutput::Owned(Value::Null),
SqliteValueRef::Integer(i) => ToSqlOutput::Owned(Value::Integer(*i)),
SqliteValueRef::Real(f) => ToSqlOutput::Owned(Value::Real(*f)),
SqliteValueRef::Text(t) => ToSqlOutput::Borrowed(ValueRef::Text(t.as_bytes())),
SqliteValueRef::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
})
}
}

#[allow(clippy::large_enum_variant)]
#[derive(Debug, Default, Clone, Serialize, Deserialize, PartialEq, Hash)]
#[serde(untagged)]
Expand Down Expand Up @@ -655,3 +742,39 @@ impl ToSql for ColumnName {
self.0.as_str().to_sql()
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_statement_serialization() {
let s = serde_json::to_string(&vec![Statement::WithParams(
"select 1
from table
where column = ?"
.into(),
vec!["my-value".into()],
)])
.unwrap();
println!("{s}");

let stmts: Vec<Statement> = serde_json::from_str(&s).unwrap();
println!("stmts: {stmts:?}");

let json = r#"[["some statement",[1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]]]"#;

let value: serde_json::Value = serde_json::from_str(json).unwrap();
println!("value: {value:#?}");

let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
println!("stmts: {stmts:?}");

let json = r#"[{"query": "some statement", "params": [1,"encodedID","nodeName",1,"Name","State",true,true,"",1234,1698084893487,1698084893487]}]"#;
let value: serde_json::Value = serde_json::from_str(json).unwrap();
println!("value: {value:#?}");

let stmts: Vec<Statement> = serde_json::from_str(json).unwrap();
println!("stmts: {stmts:?}");
}
}
28 changes: 12 additions & 16 deletions crates/corro-tpl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use compact_str::ToCompactString;
use corro_client::sub::SubscriptionStream;
use corro_client::CorrosionApiClient;
use corro_types::api::QueryEvent;
use corro_types::api::SqliteParam;
use corro_types::api::Statement;
use corro_types::change::SqliteValue;
use futures::StreamExt;
Expand Down Expand Up @@ -536,33 +537,28 @@ impl Engine {
}
});

fn dyn_to_sql(v: Dynamic) -> Result<SqliteValue, Box<EvalAltResult>> {
fn dyn_to_sql(v: Dynamic) -> Result<SqliteParam, Box<EvalAltResult>> {
Ok(match v.type_name() {
"()" => SqliteValue::Null,
"i64" => SqliteValue::Integer(
"()" => SqliteParam::Null,
"i64" => SqliteParam::Integer(
v.as_int()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to i64")))?,
),
"f64" => SqliteValue::Real(corro_types::api::Real(
"f64" => SqliteParam::Real(
v.as_float()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to f64")))?,
)),
"bool" => {
if v.as_bool()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?
{
SqliteValue::Integer(1)
} else {
SqliteValue::Integer(0)
}
}
"blob" => SqliteValue::Blob(
),
"bool" => SqliteParam::Bool(
v.as_bool()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to bool")))?,
),
"blob" => SqliteParam::Blob(
v.into_blob()
.map_err(|_e| Box::new(EvalAltResult::from("could not cast to blob")))?
.into(),
),
// convert everything else into a string, including a string
_ => SqliteValue::Text(v.to_compact_string()),
_ => SqliteParam::Text(v.to_compact_string()),
})
}

Expand Down
6 changes: 3 additions & 3 deletions crates/corrosion/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use command::{
tls::{generate_ca, generate_client_cert, generate_server_cert},
tpl::TemplateFlags,
};
use corro_api_types::SqliteValue;
use corro_api_types::SqliteParam;
use corro_client::CorrosionApiClient;
use corro_types::{
api::{ExecResult, QueryEvent, Statement},
Expand Down Expand Up @@ -301,7 +301,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
} else {
Statement::WithParams(
query.clone(),
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
)
};

Expand Down Expand Up @@ -359,7 +359,7 @@ async fn process_cli(cli: Cli) -> eyre::Result<()> {
} else {
Statement::WithParams(
query.clone(),
param.iter().map(|p| SqliteValue::Text(p.into())).collect(),
param.iter().map(|p| SqliteParam::Text(p.into())).collect(),
)
};

Expand Down