diff --git a/core/bin/dust_api.rs b/core/bin/dust_api.rs index 0eaf5566055a..c2ee3bba4733 100644 --- a/core/bin/dust_api.rs +++ b/core/bin/dust_api.rs @@ -22,7 +22,7 @@ use dust::{ project::{self}, providers::provider::{provider, ProviderID}, run, - sqlite_workers::sqlite_workers, + sqlite_workers::client, stores::postgres, stores::store, utils::{self, error_response, APIError, APIResponse}, @@ -142,7 +142,7 @@ impl APIState { let store = self.store.clone(); tokio::task::spawn(async move { match store - .sqlite_workers_cleanup(sqlite_workers::HEARTBEAT_INTERVAL_MS) + .sqlite_workers_cleanup(client::HEARTBEAT_INTERVAL_MS) .await { Err(e) => { @@ -2049,31 +2049,57 @@ async fn databases_rows_retrieve( String, String, )>, - extract::Extension(state): extract::Extension>, ) -> (StatusCode, Json) { let project = project::Project::new_from_id(project_id); match state .store - .load_database_row(&project, &data_source_id, &database_id, &table_id, &row_id) + .load_database(&project, &data_source_id, &database_id) .await { Err(e) => error_response( StatusCode::INTERNAL_SERVER_ERROR, "internal_server_error", - "Failed to upsert database rows", + "Failed to retrieve database", Some(e), ), - Ok(row) => ( - StatusCode::OK, - Json(APIResponse { - error: None, - response: Some(json!({ - "row": row - })), - }), - ), + Ok(db) => match db { + None => error_response( + StatusCode::NOT_FOUND, + "database_not_found", + &format!("No database found for id `{}`", database_id), + None, + ), + Some(db) => match db.sqlite_worker(state.store.clone()).await { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + &format!("Failed to retrieve SQLite worker: {}", e), + Some(e), + ), + Ok(worker) => match worker + .get_row(db.unique_id().as_str(), &table_id, &row_id) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + &format!("Failed to retrieve row: {}", e), + Some(e), + ), + Ok(row) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "row": row, + })), + }), + ), + }, + }, + }, } } @@ -2097,33 +2123,56 @@ async fn databases_rows_list( match state .store - .list_database_rows( - &project, - &data_source_id, - &database_id, - &table_id, - Some((query.limit, query.offset)), - ) + .load_database(&project, &data_source_id, &database_id) .await { Err(e) => error_response( StatusCode::INTERNAL_SERVER_ERROR, "internal_server_error", - "Failed to list database rows", + "Failed to retrieve database", Some(e), ), - Ok((rows, total)) => ( - StatusCode::OK, - Json(APIResponse { - error: None, - response: Some(json!({ - "rows": rows, - "offset": query.offset, - "limit": query.limit, - "total": total, - })), - }), + Ok(None) => error_response( + StatusCode::NOT_FOUND, + "database_not_found", + &format!("No database found for id `{}`", database_id), + None, ), + Ok(Some(db)) => match db.sqlite_worker(state.store.clone()).await { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + &format!("Failed to retrieve SQLite worker: {}", e), + Some(e), + ), + Ok(worker) => match worker + .get_rows( + db.unique_id().as_str(), + &table_id, + Some((query.limit, query.offset)), + ) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + &format!("Failed to list rows: {}", e), + Some(e), + ), + Ok((rows, total)) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "offset": query.offset, + "limit": query.limit, + "total": total, + "rows": rows, + })), + }), + ), + }, + }, } } @@ -2179,7 +2228,7 @@ async fn databases_query_run( // SQLite Workers -async fn sqlite_workers_hearbeat( +async fn sqlite_workers_heartbeat( extract::Path(pod_name): extract::Path, extract::Extension(state): extract::Extension>, ) -> (StatusCode, Json) { @@ -2299,7 +2348,7 @@ fn main() { }; let state = Arc::new(APIState::new(store, QdrantClients::build().await?)); - let app = Router::new() + let router = Router::new() // Index .route("/", get(index)) @@ -2434,7 +2483,6 @@ fn main() { "/projects/:project_id/data_sources/:data_source_id/databases/:database_id/query", post(databases_query_run), ) - .route("/sqlite_workers/:pod_name", post(sqlite_workers_hearbeat)) .route("/sqlite_workers/:pod_name", delete(sqlite_workers_delete)) // Misc .route("/tokenize", post(tokenize)) @@ -2448,6 +2496,13 @@ fn main() { ) .layer(extract::Extension(state.clone())); + // In a separate router, to avoid noisy tracing. + let sqlite_heartbeat_router = Router::new() + .route("/sqlite_workers/:pod_name", post(sqlite_workers_heartbeat)) + .layer(extract::Extension(state.clone())); + + let app = Router::new().merge(router).merge(sqlite_heartbeat_router); + // Start the APIState run loop. let runloop_state = state.clone(); tokio::task::spawn(async move { runloop_state.run_loop().await }); diff --git a/core/bin/sqlite_worker.rs b/core/bin/sqlite_worker.rs index 499bb731e588..3633f7b21db5 100644 --- a/core/bin/sqlite_worker.rs +++ b/core/bin/sqlite_worker.rs @@ -4,6 +4,7 @@ use std::{ atomic::{AtomicBool, Ordering}, Arc, }, + time::{Duration, Instant}, }; use anyhow::{anyhow, Result}; @@ -13,9 +14,11 @@ use axum::{ Extension, Json, Router, }; use dust::{ - sqlite_workers::sqlite_database::SqliteDatabase, + databases::database::DatabaseRow, + sqlite_workers::{sqlite_database::SqliteDatabase, store}, utils::{self, error_response, APIResponse}, }; +use dust::{databases::database::DatabaseTable, sqlite_workers::store::DatabasesStore}; use hyper::{Body, Client, Request, StatusCode}; use serde::Deserialize; use serde_json::json; @@ -26,14 +29,26 @@ use tokio::{ use tower_http::trace::{self, TraceLayer}; use tracing::Level; +// Duration after which a database is considered inactive and can be removed from the registry. +const DATABASE_TIMEOUT_DURATION: Duration = std::time::Duration::from_secs(5 * 60); // 5 minutes + +struct DatabaseEntry { + database: SqliteDatabase, + last_accessed: Instant, +} + struct WorkerState { - registry: Arc>>, + databases_store: Box, + + registry: Arc>>, is_shutting_down: Arc, } impl WorkerState { - fn new() -> Self { + fn new(databases_store: Box) -> Self { Self { + databases_store: databases_store, + // TODO: store an instant of the last access for each DB. registry: Arc::new(Mutex::new(HashMap::new())), is_shutting_down: Arc::new(AtomicBool::new(false)), @@ -47,10 +62,12 @@ impl WorkerState { } match self.heartbeat().await { - Ok(_) => utils::info("Heartbeat sent."), + Ok(_) => (), Err(e) => utils::error(&format!("Failed to send heartbeat: {:?}", e)), } - // TODO: check for inactive DBs to kill. + + self.cleanup_inactive_databases().await; + tokio::time::sleep(std::time::Duration::from_millis(1024)).await; } } @@ -73,6 +90,11 @@ impl WorkerState { self._core_request("DELETE").await } + async fn cleanup_inactive_databases(&self) { + let mut registry = self.registry.lock().await; + registry.retain(|_, entry| entry.last_accessed.elapsed() < DATABASE_TIMEOUT_DURATION); + } + async fn _core_request(&self, method: &str) -> Result<()> { let hostname = match std::env::var("HOSTNAME") { Ok(hostname) => hostname, @@ -109,6 +131,7 @@ async fn index() -> &'static str { #[derive(Deserialize)] struct DbQueryBody { query: String, + tables: Vec, } async fn db_query( @@ -118,16 +141,20 @@ async fn db_query( ) -> (StatusCode, Json) { let mut registry = state.registry.lock().await; - let db = match registry.get(&db_id) { - Some(db) => db, - None => { - let db = SqliteDatabase::new(&db_id); - registry.insert(db_id.clone(), db); - registry.get(&db_id).unwrap() - } - }; + let entry = registry + .entry(db_id.clone()) + .or_insert_with(|| DatabaseEntry { + database: SqliteDatabase::new( + db_id.clone(), + payload.tables, + state.databases_store.clone(), + ), + last_accessed: Instant::now(), + }); - match db.query(payload.query).await { + entry.last_accessed = Instant::now(); + + match entry.database.query(payload.query).await { Ok(results) => ( axum::http::StatusCode::OK, Json(APIResponse { @@ -144,6 +171,123 @@ async fn db_query( } } +#[derive(serde::Deserialize)] +struct DatabasesRowsUpsertPayload { + rows: Vec, + truncate: Option, +} + +async fn databases_rows_upsert( + extract::Path((database_id, table_id)): extract::Path<(String, String)>, + extract::Json(payload): extract::Json, + Extension(state): Extension>, +) -> (StatusCode, Json) { + // Terminate the running DB thread if it exists. + let mut registry = state.registry.lock().await; + match registry.get(&database_id) { + Some(_) => { + // Removing the DB from the registry will terminate the thread once pending queries are + // finished. + registry.remove(&database_id); + } + None => (), + } + + let truncate = match payload.truncate { + Some(v) => v, + None => false, + }; + + match state + .databases_store + .batch_upsert_database_rows(&database_id, &table_id, &payload.rows, truncate) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + "Failed to upsert database rows", + Some(e), + ), + Ok(()) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "success": true + })), + }), + ), + } +} + +#[derive(serde::Deserialize)] +struct DatabasesRowsListQuery { + offset: Option, + limit: Option, +} + +async fn databases_rows_list( + extract::Path((database_id, table_id)): extract::Path<(String, String)>, + extract::Query(query): extract::Query, + Extension(state): Extension>, +) -> (StatusCode, Json) { + let limit_offset = match (query.limit, query.offset) { + (Some(limit), Some(offset)) => Some((limit, offset)), + _ => None, + }; + match state + .databases_store + .list_database_rows(&database_id, &table_id, limit_offset) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + "Failed to list database rows", + Some(e), + ), + Ok((rows, total)) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "rows": rows, + "total": total, + })), + }), + ), + } +} + +async fn databases_row_retrieve( + extract::Path((database_id, table_id)): extract::Path<(String, String)>, + extract::Path(row_id): extract::Path, + Extension(state): Extension>, +) -> (StatusCode, Json) { + match state + .databases_store + .load_database_row(&database_id, &table_id, &row_id) + .await + { + Err(e) => error_response( + StatusCode::INTERNAL_SERVER_ERROR, + "internal_server_error", + "Failed to retrieve database row", + Some(e), + ), + Ok(row) => ( + StatusCode::OK, + Json(APIResponse { + error: None, + response: Some(json!({ + "row": row, + })), + }), + ), + } +} + fn main() { let rt = tokio::runtime::Builder::new_multi_thread() .worker_threads(32) @@ -158,10 +302,33 @@ fn main() { .with_ansi(false) .init(); - let state = Arc::new(WorkerState::new()); + let databases_store: Box = + match std::env::var("DATABASES_STORE_DATABASE_URI") { + Ok(db_uri) => { + let s = store::PostgresDatabasesStore::new(&db_uri).await?; + s.init().await?; + Box::new(s) + } + Err(_) => Err(anyhow!("DATABASES_STORE_DATABASE_URI not set."))?, + }; + + let state = Arc::new(WorkerState::new(databases_store)); + let app = Router::new() .route("/", get(index)) - .route("/db/:db_id", post(db_query)) + .route("/databases/:database_id", post(db_query)) + .route( + "/databases/:database_id/tables/:table_id/rows", + post(databases_rows_upsert), + ) + .route( + "/databases/:database_id/tables/:table_id/rows", + get(databases_rows_list), + ) + .route( + "/databases/:database_id/tables/:table_id/rows/:row_id", + get(databases_row_retrieve), + ) .layer( TraceLayer::new_for_http() .make_span_with(trace::DefaultMakeSpan::new().level(Level::INFO)) diff --git a/core/src/databases/database.rs b/core/src/databases/database.rs index ff1480ef9eac..0f628917e080 100644 --- a/core/src/databases/database.rs +++ b/core/src/databases/database.rs @@ -1,10 +1,13 @@ use super::table_schema::TableSchema; -use crate::{project::Project, stores::store::Store, utils}; +use crate::{ + project::Project, + sqlite_workers::client::{SqliteWorker, HEARTBEAT_INTERVAL_MS}, + stores::store::Store, + utils, +}; use anyhow::{anyhow, Result}; use futures::future::try_join_all; use itertools::Itertools; -use rayon::prelude::*; -use rusqlite::{params_from_iter, Connection}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -114,26 +117,25 @@ impl Database { } }; - try_join_all(vec![ - store.update_database_table_schema( + store + .update_database_table_schema( &self.project, &self.data_source_id, &self.database_id, table_id, &table_schema, - ), - store.batch_upsert_database_rows( - &self.project, - &self.data_source_id, - &self.database_id, - table_id, - &rows, - truncate, - ), - ]) - .await?; + ) + .await?; - Ok(()) + // Call the SqliteWorker to update the rows contents. + // Note: if this fails, the DB will still contain the new schema, but the rows will not be updated. + // This isn't too bad, because the merged schema is necessarily backward-compatible with the previous one. + // The other way around would not be true -- old schema doesn't necessarily work with the new rows. + // This is why we cannot `try_join_all`. + let sqlite_worker = self.sqlite_worker(store.clone()).await?; + sqlite_worker + .upsert_rows(&self.unique_id(), table_id, rows, truncate) + .await } pub async fn delete(&self, store: Box) -> Result<()> { @@ -149,98 +151,6 @@ impl Database { } } - pub async fn create_in_memory_sqlite_conn( - &self, - store: Box, - ) -> Result { - match self.db_type { - DatabaseType::REMOTE => Err(anyhow!( - "Cannot build an in-memory SQLite DB for a remote database." - )), - DatabaseType::LOCAL => { - let time_build_db_start = utils::now(); - - let tables = self.get_tables(store.clone()).await?; - utils::done(&format!( - "DSSTRUCTSTAT Finished retrieving schema: duration={}ms", - utils::now() - time_build_db_start - )); - - let time_get_rows_start = utils::now(); - let rows = self.get_rows(store.clone()).await?; - utils::done(&format!( - "DSSTRUCTSTAT Finished retrieving rows: duration={}ms", - utils::now() - time_get_rows_start - )); - - let generate_create_table_sql_start = utils::now(); - let create_tables_sql: String = tables - .into_iter() - .filter_map(|t| match t.schema() { - Some(s) => { - if s.is_empty() { - None - } else { - Some(s.get_create_table_sql_string(t.name())) - } - } - None => None, - }) - .collect::>() - .join("\n"); - utils::done(&format!( - "DSSTRUCTSTAT Finished generating create table SQL: duration={}ms", - utils::now() - generate_create_table_sql_start - )); - - let conn = rusqlite::Connection::open_in_memory()?; - - let create_tables_execute_start = utils::now(); - conn.execute_batch(&create_tables_sql)?; - utils::done(&format!( - "DSSTRUCTSTAT Finished creating tables: duration={}ms", - utils::now() - create_tables_execute_start - )); - - let insert_execute_start = utils::now(); - rows.iter() - .filter(|(_, rows)| !rows.is_empty()) - .map(|(table, rows)| { - if table.schema().is_none() { - Err(anyhow!("No schema found for table {}", table.name()))?; - } - let table_schema = table.schema().unwrap(); - let (sql, field_names) = table_schema.get_insert_sql(table.name()); - let mut stmt = conn.prepare(&sql)?; - - rows.par_iter() - .map(|r| match table_schema.get_insert_params(&field_names, r) { - Ok(params) => Ok(params_from_iter(params)), - Err(e) => Err(anyhow!( - "Error getting insert params for row {}: {}", - r.row_id(), - e - )), - }) - .collect::>>()? - .into_iter() - .map(|params| match stmt.execute(params) { - Ok(_) => Ok(()), - Err(e) => Err(anyhow!("Error inserting row: {}", e)), - }) - .collect::>>() - }) - .collect::>>()?; - utils::done(&format!( - "DSSTRUCTSTAT Finished inserting rows: duration={}ms", - utils::now() - insert_execute_start - )); - - Ok(conn) - } - } - } - pub async fn query( &self, store: Box, @@ -249,74 +159,16 @@ impl Database { match self.db_type { DatabaseType::REMOTE => Err(anyhow!("Remote DB not implemented.")), DatabaseType::LOCAL => { - let conn = self.create_in_memory_sqlite_conn(store.clone()).await?; + let tables = self.get_tables(store.clone()).await?; + let sqlite_worker = self.sqlite_worker(store.clone()).await?; let time_query_start = utils::now(); + let result_rows = sqlite_worker + .execute_query(&self.unique_id(), tables, query) + .await?; - let mut stmt = conn.prepare(query)?; - - // copy the column names into a vector of strings - let column_names = stmt - .column_names() - .into_iter() - .map(|x| x.to_string()) - .collect::>(); - - // Execute the query and collect the results in a vector of serde_json::Value objects. - let result_rows = stmt - .query_and_then([], |row| { - column_names - .iter() - .enumerate() - .map(|(i, column_name)| { - Ok(( - column_name.clone(), - match row.get(i) { - Err(e) => Err(anyhow!( - "Failed to retrieve value for column {}: {}", - column_name, - e - )), - Ok(v) => match v { - rusqlite::types::Value::Integer(i) => { - Ok(serde_json::Value::Number(i.into())) - } - rusqlite::types::Value::Real(f) => { - match serde_json::Number::from_f64(f) { - Some(n) => Ok(serde_json::Value::Number(n)), - None => Err(anyhow!( - "Invalid float value for column {}", - column_name - )), - } - } - rusqlite::types::Value::Text(t) => { - Ok(serde_json::Value::String(t.clone())) - } - rusqlite::types::Value::Blob(b) => { - match String::from_utf8(b.clone()) { - Err(_) => Err(anyhow!( - "Invalid UTF-8 sequence for column {}", - column_name - )), - Ok(s) => Ok(serde_json::Value::String(s)), - } - } - rusqlite::types::Value::Null => { - Ok(serde_json::Value::Null) - } - }, - }?, - )) - }) - .collect::>() - })? - .collect::>>()? - .into_par_iter() - .map(|value| DatabaseResult { value }) - .collect::>(); utils::done(&format!( - "DSSTRUCTSTAT Finished executing user query: duration={}ms", + "DSSTRUCTSTAT Finished executing user query on worker: duration={}ms", utils::now() - time_query_start )); @@ -345,32 +197,21 @@ impl Database { .list_databases_tables(&self.project, &self.data_source_id, &self.database_id, None) .await?; - // Concurrently retrieve table rows. - Ok(futures::future::try_join_all( - tables - .into_iter() - .map(|table| { - let store = store.clone(); - - async move { - let (rows, _) = store - .list_database_rows( - &self.project, - self.data_source_id.as_str(), - self.database_id.as_str(), - table.table_id(), - None, - ) - .await?; - - Ok::<_, anyhow::Error>((table, rows)) - } - }) - .collect::>(), - ) - .await? - .into_iter() - .collect::>()) + // Get the SQLite worker for this database. + let sqlite_worker = &self.sqlite_worker(store.clone()).await?; + + Ok(try_join_all(tables.into_iter().map(|table| { + let database_id = self.unique_id(); + let table_id = table.table_id().to_string(); + + async move { + let (rows, _) = sqlite_worker + .get_rows(&database_id, &table_id, None) + .await?; + Ok::<_, anyhow::Error>((table, rows)) + } + })) + .await?) } // Getters @@ -386,9 +227,36 @@ impl Database { pub fn name(&self) -> &str { &self.name } + pub fn unique_id(&self) -> String { + format!( + "{}__{}__{}", + self.project.project_id(), + self.data_source_id, + self.database_id + ) + } + + pub async fn sqlite_worker(&self, store: Box) -> Result { + let worker = store + .assign_live_sqlite_worker_to_database( + &self.project, + &self.data_source_id, + &self.database_id, + HEARTBEAT_INTERVAL_MS, + ) + .await?; + + match worker.is_alive() { + true => Ok(worker), + false => Err(anyhow!( + "No live SQLite worker found for database {}", + self.database_id + )), + } + } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, Deserialize)] pub struct DatabaseTable { created: u64, database_id: String, @@ -484,7 +352,7 @@ impl HasValue for DatabaseRow { #[derive(Debug, Deserialize, Serialize, Clone)] pub struct DatabaseResult { - value: Value, + pub value: Value, } impl HasValue for DatabaseResult { diff --git a/core/src/lib.rs b/core/src/lib.rs index 8a7142a0f33c..6d7084cb96b2 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -61,8 +61,9 @@ pub mod blocks { } pub mod sqlite_workers { + pub mod client; pub mod sqlite_database; - pub mod sqlite_workers; + pub mod store; } pub mod deno { diff --git a/core/src/sqlite_workers/client.rs b/core/src/sqlite_workers/client.rs new file mode 100644 index 000000000000..309118133cd3 --- /dev/null +++ b/core/src/sqlite_workers/client.rs @@ -0,0 +1,237 @@ +use anyhow::{anyhow, Result}; +use hyper::{Body, Client, Request}; +use serde::Deserialize; +use serde_json::json; + +use crate::{ + databases::database::{DatabaseResult, DatabaseRow, DatabaseTable}, + utils, +}; + +pub const HEARTBEAT_INTERVAL_MS: u64 = 3_000; + +pub struct SqliteWorker { + last_heartbeat: u64, + pod_name: String, +} + +impl SqliteWorker { + pub fn new(pod_name: String, last_heartbeat: u64) -> Self { + Self { + last_heartbeat: last_heartbeat, + pod_name, + } + } + + pub fn is_alive(&self) -> bool { + let now = utils::now(); + let elapsed = now - self.last_heartbeat; + + elapsed < HEARTBEAT_INTERVAL_MS + } + + pub async fn upsert_rows( + &self, + database_unique_id: &str, + table_id: &str, + rows: Vec, + truncate: bool, + ) -> Result<()> { + let url = self.url()?; + let req = Request::builder() + .method("POST") + .uri(format!( + "{}/databases/{}/tables/{}/rows", + url, database_unique_id, table_id + )) + .header("Content-Type", "application/json") + .body(Body::from( + json!({ + "rows": rows, + "truncate": truncate, + }) + .to_string(), + ))?; + + let res = Client::new().request(req).await?; + + match res.status().as_u16() { + 200 => Ok(()), + s => Err(anyhow!( + "Failed to send rows to sqlite worker. Status: {}", + s + )), + } + } + + pub async fn get_rows( + &self, + database_unique_id: &str, + table_id: &str, + limit_offset: Option<(usize, usize)>, + ) -> Result<(Vec, usize)> { + let worker_url = self.url()?; + + let mut uri = format!( + "{}/databases/{}/tables/{}/rows", + worker_url, database_unique_id, table_id + ); + + if let Some((limit, offset)) = limit_offset { + uri = format!("{}?limit={}&offset={}", uri, limit, offset); + } + + let req = Request::builder() + .method("GET") + .uri(uri) + .header("Content-Type", "application/json") + .body(Body::empty())?; + + let res = Client::new().request(req).await?; + + #[derive(Deserialize)] + struct GetRowsResponse { + rows: Vec, + total: usize, + } + #[derive(Deserialize)] + struct GetRowsResponseBody { + error: Option, + response: Option, + } + + match res.status().as_u16() { + 200 => { + let body = hyper::body::to_bytes(res.into_body()).await?; + let res: GetRowsResponseBody = serde_json::from_slice(&body)?; + let (rows, total) = match res.error { + Some(e) => Err(anyhow!("Error retrieving rows: {}", e))?, + None => match res.response { + Some(r) => (r.rows, r.total), + None => Err(anyhow!("No rows found in response"))?, + }, + }; + + Ok((rows, total)) + } + s => Err(anyhow!( + "Failed to retrieve rows from sqlite worker. Status: {}", + s + ))?, + } + } + + pub async fn get_row( + &self, + database_unique_id: &str, + table_id: &str, + row_id: &str, + ) -> Result> { + let worker_url = self.url()?; + + let uri = format!( + "{}/databases/{}/tables/{}/rows/{}", + worker_url, database_unique_id, table_id, row_id + ); + + let req = Request::builder() + .method("GET") + .uri(uri) + .header("Content-Type", "application/json") + .body(Body::empty())?; + + let res = Client::new().request(req).await?; + + #[derive(Deserialize)] + struct GetRowResponseBody { + error: Option, + response: Option, + } + + match res.status().as_u16() { + 200 => { + let body = hyper::body::to_bytes(res.into_body()).await?; + let res: GetRowResponseBody = serde_json::from_slice(&body)?; + match res.error { + Some(e) => Err(anyhow!("Error retrieving row: {}", e))?, + None => match res.response { + Some(r) => Ok(Some(r)), + None => Ok(None), + }, + } + } + s => Err(anyhow!( + "Failed to retrieve row from sqlite worker. Status: {}", + s + ))?, + } + } + + pub async fn execute_query( + &self, + database_unique_id: &str, + tables: Vec, + query: &str, + ) -> Result> { + let worker_url = self.url()?; + + let req = Request::builder() + .method("POST") + .uri(format!("{}/databases/{}", worker_url, database_unique_id)) + .header("Content-Type", "application/json") + .body(Body::from( + json!({ + "tables": tables, + "query": query, + }) + .to_string(), + ))?; + + let res = Client::new().request(req).await?; + + #[derive(Deserialize)] + struct ExecuteQueryResponseBody { + error: Option, + response: Option>, + } + + match res.status().as_u16() { + 200 => { + let body = hyper::body::to_bytes(res.into_body()).await?; + let res: ExecuteQueryResponseBody = serde_json::from_slice(&body)?; + match res.error { + Some(e) => Err(anyhow!("Error executing query: {}", e))?, + None => match res.response { + Some(r) => Ok(r), + None => Err(anyhow!("No response found"))?, + }, + } + } + s => Err(anyhow!( + "Failed to execute query on sqlite worker. Status: {}", + s + ))?, + } + } + + pub fn url(&self) -> Result { + match std::env::var("IS_LOCAL_DEV") { + Ok(_) => return Ok("http://localhost:3005".to_string()), + Err(_) => (), + } + let cluster_namespace = match std::env::var("CLUSTER_NAMESPACE") { + Ok(n) => n, + Err(_) => Err(anyhow!("CLUSTER_NAMESPACE env var not set"))?, + }; + let core_sqlite_headless_service_name = + match std::env::var("CORE_SQLITE_HEADLESS_SERVICE_NAME") { + Ok(s) => s, + Err(_) => Err(anyhow!("CORE_SQLITE_HEADLESS_SERVICE_NAME env var not set"))?, + }; + + Ok(format!( + "http://{}.{}.{}.svc.cluster.local", + self.pod_name, core_sqlite_headless_service_name, cluster_namespace + )) + } +} diff --git a/core/src/sqlite_workers/sqlite_database.rs b/core/src/sqlite_workers/sqlite_database.rs index 41c8efb645c7..d16de2f2503b 100644 --- a/core/src/sqlite_workers/sqlite_database.rs +++ b/core/src/sqlite_workers/sqlite_database.rs @@ -1,14 +1,26 @@ +use crate::{ + databases::database::{DatabaseResult, DatabaseRow, DatabaseTable}, + utils, +}; use anyhow::{anyhow, Result}; -use rusqlite::Connection; -use tokio::sync::{ - mpsc::{self, Sender}, - oneshot, +use futures::future::try_join_all; +use rayon::prelude::*; +use rusqlite::{params_from_iter, Connection}; +use tokio::task; +use tokio::{ + runtime::Handle, + sync::{ + mpsc::{self, Sender}, + oneshot, + }, }; +use super::store::DatabasesStore; + pub enum DbMessage { Execute { query: String, - response: oneshot::Sender>>, + response: oneshot::Sender>>, }, } @@ -17,37 +29,36 @@ pub struct SqliteDatabase { } impl SqliteDatabase { - pub fn new(database_id: &str) -> Self { + pub fn new( + database_id: String, + tables: Vec, + databases_store: Box, + ) -> Self { let (tx, mut rx) = mpsc::channel(32); - let db_id_clone = database_id.to_string(); - tokio::spawn(async move { - // TODD: init code - let conn = Connection::open_in_memory().unwrap(); - while let Some(message) = rx.recv().await { + // We use a blocking thread because the DB thread will sometimes be CPU-heavy (creating the DB, executing the query). + let runtime_handle = tokio::runtime::Handle::current(); + + task::spawn_blocking(move || { + let conn = + create_in_memory_sqlite_db(runtime_handle, databases_store, database_id, &tables) + .unwrap(); + + while let Some(message) = rx.blocking_recv() { match message { DbMessage::Execute { query, response } => { - println!("Executing query: {} on db: {}", query, db_id_clone); - // Execute the query and collect results - let mut stmt = conn.prepare(&query).unwrap(); - let rows = stmt.query_map([], |row| row.get(0)).unwrap(); - - let mut results = Vec::new(); - for value in rows { - results.push(value.unwrap()); - } - - // Send the results back through the oneshot channel - let _ = response.send(Ok(results)); + let _ = response.send(Ok(execute_query_on_conn(&conn, query)?)); } } } + + Ok::<(), anyhow::Error>(()) }); Self { sender: tx } } - pub async fn query(&self, query: String) -> Result> { + pub async fn query(&self, query: String) -> Result> { // Create a oneshot channel for the response let (response_tx, response_rx) = oneshot::channel(); @@ -67,3 +78,172 @@ impl SqliteDatabase { } } } + +fn create_in_memory_sqlite_db( + runtime_handle: Handle, + databases_store: Box, + database_id: String, + tables: &Vec, +) -> Result { + async fn fetch_rows( + database_id: &str, + table: &DatabaseTable, + databases_store: Box, + ) -> Result<(DatabaseTable, Vec)> { + let (rows, _) = databases_store + .list_database_rows(&database_id, table.table_id(), None) + .await?; + + Ok((table.clone(), rows)) + } + + let time_get_rows_start = utils::now(); + + let rows = runtime_handle.block_on(async move { + try_join_all( + tables + .iter() + .map(|table| fetch_rows(&database_id, table, databases_store.clone())), + ) + .await + })?; + utils::done(&format!( + "DSSTRUCTSTAT - WORKER Finished retrieving rows: duration={}ms", + utils::now() - time_get_rows_start + )); + + let generate_create_table_sql_start = utils::now(); + let create_tables_sql: String = tables + .into_iter() + .filter_map(|t| match t.schema() { + Some(s) => { + if s.is_empty() { + None + } else { + Some(s.get_create_table_sql_string(t.name())) + } + } + None => None, + }) + .collect::>() + .join("\n"); + utils::done(&format!( + "DSSTRUCTSTAT - WORKER Finished generating create table SQL: duration={}ms", + utils::now() - generate_create_table_sql_start + )); + + let conn = Connection::open_in_memory().unwrap(); + + let create_tables_execute_start = utils::now(); + conn.execute_batch(&create_tables_sql)?; + utils::done(&format!( + "DSSTRUCTSTAT - WORKER Finished creating tables: duration={}ms", + utils::now() - create_tables_execute_start + )); + + let insert_execute_start = utils::now(); + rows.iter() + .filter(|(_, rows)| !rows.is_empty()) + .map(|(table, rows)| { + if table.schema().is_none() { + Err(anyhow!("No schema found for table {}", table.name()))?; + } + let table_schema = table.schema().unwrap(); + let (sql, field_names) = table_schema.get_insert_sql(table.name()); + let mut stmt = conn.prepare(&sql)?; + + rows.par_iter() + .map(|r| match table_schema.get_insert_params(&field_names, r) { + Ok(params) => Ok(params_from_iter(params)), + Err(e) => Err(anyhow!( + "Error getting insert params for row {}: {}", + r.row_id(), + e + )), + }) + .collect::>>()? + .into_iter() + .map(|params| match stmt.execute(params) { + Ok(_) => Ok(()), + Err(e) => Err(anyhow!("Error inserting row: {}", e)), + }) + .collect::>>() + }) + .collect::>>()?; + utils::done(&format!( + "DSSTRUCTSTAT - WORKER Finished inserting rows: duration={}ms", + utils::now() - insert_execute_start + )); + + Ok(conn) +} + +fn execute_query_on_conn(conn: &Connection, query: String) -> Result> { + let time_query_start = utils::now(); + // Execute the query and collect results + let mut stmt = conn.prepare(&query).unwrap(); + // copy the column names into a vector of strings + let column_names = stmt + .column_names() + .into_iter() + .map(|x| x.to_string()) + .collect::>(); + + let result_rows = stmt + .query_and_then([], |row| { + column_names + .iter() + .enumerate() + .map(|(i, column_name)| { + Ok(( + column_name.clone(), + match row.get(i) { + Err(e) => Err(anyhow!( + "Failed to retrieve value for column {}: {}", + column_name, + e + )), + Ok(v) => match v { + rusqlite::types::Value::Integer(i) => { + Ok(serde_json::Value::Number(i.into())) + } + rusqlite::types::Value::Real(f) => { + match serde_json::Number::from_f64(f) { + Some(n) => Ok(serde_json::Value::Number(n)), + None => Err(anyhow!( + "Invalid float value for column {}", + column_name + )), + } + } + rusqlite::types::Value::Text(t) => { + Ok(serde_json::Value::String(t.clone())) + } + rusqlite::types::Value::Blob(b) => { + match String::from_utf8(b.clone()) { + Err(_) => Err(anyhow!( + "Invalid UTF-8 sequence for column {}", + column_name + )), + Ok(s) => Ok(serde_json::Value::String(s)), + } + } + rusqlite::types::Value::Null => Ok(serde_json::Value::Null), + }, + }?, + )) + }) + .collect::>() + })? + .collect::>>()? + .into_par_iter() + .map(|value| DatabaseResult { value }) + .collect::>(); + + utils::done(&format!( + "DSSTRUCTSTAT - WORKER Finished executing user query: duration={}ms", + utils::now() - time_query_start + )); + + Ok(result_rows) +} diff --git a/core/src/sqlite_workers/sqlite_workers.rs b/core/src/sqlite_workers/sqlite_workers.rs deleted file mode 100644 index 27d16a50344e..000000000000 --- a/core/src/sqlite_workers/sqlite_workers.rs +++ /dev/null @@ -1,43 +0,0 @@ -use anyhow::{anyhow, Result}; - -use crate::utils; - -pub const HEARTBEAT_INTERVAL_MS: u64 = 3_000; - -pub struct SqliteWorker { - last_heartbeat: u64, - pod_name: String, -} - -impl SqliteWorker { - pub fn new(pod_name: String, last_heartbeat: u64) -> Self { - Self { - last_heartbeat: last_heartbeat, - pod_name, - } - } - - pub fn is_alive(&self) -> bool { - let now = utils::now(); - let elapsed = now - self.last_heartbeat; - - elapsed < HEARTBEAT_INTERVAL_MS - } - - pub fn url(&self) -> Result { - let cluster_namespace = match std::env::var("CLUSTER_NAMESPACE") { - Ok(n) => n, - Err(_) => Err(anyhow!("CLUSTER_NAMESPACE env var not set"))?, - }; - let core_sqlite_headless_service_name = - match std::env::var("CORE_SQLITE_HEADLESS_SERVICE_NAME") { - Ok(s) => s, - Err(_) => Err(anyhow!("CORE_SQLITE_HEADLESS_SERVICE_NAME env var not set"))?, - }; - - Ok(format!( - "http://{}.{}.{}.svc.cluster.local", - self.pod_name, core_sqlite_headless_service_name, cluster_namespace - )) - } -} diff --git a/core/src/sqlite_workers/store.rs b/core/src/sqlite_workers/store.rs new file mode 100644 index 000000000000..fd1a24da4640 --- /dev/null +++ b/core/src/sqlite_workers/store.rs @@ -0,0 +1,232 @@ +use anyhow::Result; +use async_trait::async_trait; +use bb8::Pool; +use bb8_postgres::PostgresConnectionManager; +use serde_json::Value; +use tokio_postgres::{types::ToSql, NoTls}; + +use crate::{databases::database::DatabaseRow, utils}; + +#[async_trait] +pub trait DatabasesStore { + async fn init(&self) -> Result<()>; + async fn load_database_row( + &self, + database_id: &str, + table_id: &str, + row_id: &str, + ) -> Result>; + async fn list_database_rows( + &self, + database_id: &str, + table_id: &str, + limit_offset: Option<(usize, usize)>, + ) -> Result<(Vec, usize)>; + async fn batch_upsert_database_rows( + &self, + database_id: &str, + table_id: &str, + rows: &Vec, + truncate: bool, + ) -> Result<()>; + + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Self { + self.clone_box() + } +} + +#[derive(Clone)] +pub struct PostgresDatabasesStore { + pool: Pool>, +} + +impl PostgresDatabasesStore { + pub async fn new(db_uri: &str) -> Result { + let manager = PostgresConnectionManager::new_from_stringlike(db_uri, NoTls)?; + let pool = Pool::builder().max_size(16).build(manager).await?; + Ok(Self { pool }) + } +} + +#[async_trait] +impl DatabasesStore for PostgresDatabasesStore { + async fn init(&self) -> Result<()> { + let conn = self.pool.get().await?; + for table in POSTGRES_TABLES { + conn.execute(table, &[]).await?; + } + for index in SQL_INDEXES { + conn.execute(index, &[]).await?; + } + Ok(()) + } + + async fn load_database_row( + &self, + database_id: &str, + table_id: &str, + row_id: &str, + ) -> Result> { + let pool = self.pool.clone(); + let c = pool.get().await?; + + let stmt = c + .prepare( + "SELECT created, row_id, content + FROM databases_rows + WHERE database_id = $1 AND table_id = $2 AND row_id = $3 + LIMIT 1", + ) + .await?; + + let r = c.query(&stmt, &[&database_id, &table_id, &row_id]).await?; + + let d: Option<(i64, String, String)> = match r.len() { + 0 => None, + 1 => Some((r[0].get(0), r[0].get(1), r[0].get(2))), + _ => unreachable!(), + }; + + match d { + None => Ok(None), + Some((_, row_id, data)) => { + Ok(Some(DatabaseRow::new(row_id, serde_json::from_str(&data)?))) + } + } + } + + async fn list_database_rows( + &self, + database_id: &str, + table_id: &str, + limit_offset: Option<(usize, usize)>, + ) -> Result<(Vec, usize)> { + let pool = self.pool.clone(); + let c = pool.get().await?; + + let mut params: Vec<&(dyn ToSql + Sync)> = vec![&database_id, &table_id]; + let mut query = "SELECT created, row_id, content + FROM databases_rows + WHERE database_id = $1 AND table_id = $2 + ORDER BY created DESC" + .to_string(); + + let limit_i64: i64; + let offset_i64: i64; + if let Some((limit, offset)) = limit_offset { + query.push_str(" LIMIT $2 OFFSET $3"); + limit_i64 = limit as i64; + offset_i64 = offset as i64; + params.push(&limit_i64); + params.push(&(offset_i64)); + } + + let rows = c.query(&query, ¶ms).await?; + + let rows: Vec = rows + .iter() + .map(|row| { + let row_id: String = row.get(1); + let data: String = row.get(2); + let content: Value = serde_json::from_str(&data)?; + Ok(DatabaseRow::new(row_id, content)) + }) + .collect::>>()?; + + let total = match limit_offset { + None => rows.len(), + Some(_) => { + let t: i64 = c + .query_one( + "SELECT COUNT(*) + FROM databases_rows + WHERE database_id = $1 AND table_id = $2", + &[&database_id, &table_id], + ) + .await? + .get(0); + t as usize + } + }; + + Ok((rows, total)) + } + + async fn batch_upsert_database_rows( + &self, + database_id: &str, + table_id: &str, + rows: &Vec, + truncate: bool, + ) -> Result<()> { + let pool = self.pool.clone(); + let mut c = pool.get().await?; + // Start transaction. + let c = c.transaction().await?; + + // Truncate table if required. + if truncate { + let stmt = c + .prepare( + "DELETE FROM databases_rows + WHERE database_id = $1 AND table_id = $2", + ) + .await?; + c.execute(&stmt, &[&database_id, &table_id]).await?; + } + + // Prepare insertion/updation statement. + let stmt = c + .prepare( + "INSERT INTO databases_rows + (id, database_id, table_id, row_id, created, content) + VALUES (DEFAULT, $1, $2, $3, $4, $5) + ON CONFLICT (database_id, table_id, row_id) DO UPDATE + SET content = EXCLUDED.content", + ) + .await?; + + for row in rows { + c.execute( + &stmt, + &[ + &database_id, + &table_id, + &row.row_id(), + &(utils::now() as i64), + &row.content().to_string(), + ], + ) + .await?; + } + + c.commit().await?; + + Ok(()) + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +pub const POSTGRES_TABLES: [&'static str; 1] = [ + // + "CREATE TABLE IF NOT EXISTS databases_rows ( + id BIGSERIAL PRIMARY KEY, + created BIGINT NOT NULL, + database_id TEXT NOT NULL, -- unique id of the database (globally) + table_id TEXT NOT NULL, -- unique within database + row_id TEXT NOT NULL, -- unique within table + content TEXT NOT NULL -- json + );", +]; + +pub const SQL_INDEXES: [&'static str; 2] = [ + "CREATE UNIQUE INDEX IF NOT EXISTS databases_rows_unique ON databases_rows (row_id, table_id, database_id);", + "CREATE INDEX IF NOT EXISTS databases_rows_database_id_table_id ON databases_rows (database_id, table_id);", +]; diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs index c02bafc1a675..b6ec93d6a5fe 100644 --- a/core/src/stores/postgres.rs +++ b/core/src/stores/postgres.rs @@ -3,7 +3,7 @@ use crate::consts::DATA_SOURCE_DOCUMENT_SYSTEM_TAG_PREFIX; use crate::data_sources::data_source::{ DataSource, DataSourceConfig, Document, DocumentVersion, SearchFilter, }; -use crate::databases::database::{Database, DatabaseRow, DatabaseTable}; +use crate::databases::database::{Database, DatabaseTable}; use crate::databases::table_schema::TableSchema; use crate::dataset::Dataset; use crate::http::request::{HttpRequest, HttpResponse}; @@ -11,7 +11,7 @@ use crate::project::Project; use crate::providers::embedder::{EmbedderRequest, EmbedderVector}; use crate::providers::llm::{LLMChatGeneration, LLMChatRequest, LLMGeneration, LLMRequest}; use crate::run::{BlockExecution, Run, RunConfig, RunStatus, RunType}; -use crate::sqlite_workers::sqlite_workers::SqliteWorker; +use crate::sqlite_workers::client::SqliteWorker; use crate::stores::store::{Store, POSTGRES_TABLES, SQL_FUNCTIONS, SQL_INDEXES}; use crate::utils; use anyhow::{anyhow, Result}; @@ -19,7 +19,10 @@ use async_trait::async_trait; use bb8::Pool; use bb8_postgres::PostgresConnectionManager; use serde_json::Value; +use std::collections::hash_map::DefaultHasher; use std::collections::{HashMap, HashSet}; +use std::hash::Hash; +use std::hash::Hasher; use std::str::FromStr; use tokio_postgres::types::ToSql; use tokio_postgres::NoTls; @@ -1993,6 +1996,125 @@ impl Store for PostgresStore { Ok(databases) } + async fn assign_live_sqlite_worker_to_database( + &self, + project: &Project, + data_source_id: &str, + database_id: &str, + ttl: u64, + ) -> Result { + let project_id = project.project_id(); + let pool = self.pool.clone(); + let mut c = pool.get().await?; + + let mut hasher = DefaultHasher::new(); + format!( + "databases-{}-{}-{}", + project_id, data_source_id, database_id + ) + .hash(&mut hasher); + let lock_key = hasher.finish() as i64; + + let tx = c.transaction().await?; + + // Acquire a transaction-level advisory lock on the database. + tx.execute("SELECT pg_advisory_xact_lock($1)", &[&(lock_key as i64)]) + .await?; + + // Get the data source row id. + let stmt = tx + .prepare( + "SELECT id FROM data_sources WHERE project = $1 AND data_source_id = $2 LIMIT 1", + ) + .await?; + let r = tx.query(&stmt, &[&project_id, &data_source_id]).await?; + let data_source_row_id: i64 = match r.len() { + 0 => Err(anyhow!("Unknown DataSource: {}", data_source_id))?, + 1 => r[0].get(0), + _ => unreachable!(), + }; + + // Check if there is already an assigned live worker. + let stmt = tx + .prepare( + "SELECT pod_name, last_heartbeat + FROM sqlite_workers + WHERE id IN ( + SELECT sqlite_worker + FROM databases + WHERE data_source = $1 AND database_id = $2 + ) AND last_heartbeat > $3 LIMIT 1", + ) + .await?; + let r = tx + .query( + &stmt, + &[ + &data_source_row_id, + &database_id, + &((utils::now() - ttl) as i64), + ], + ) + .await?; + + let worker: Option = match r.len() { + 0 => None, + 1 => { + let (pod_name, last_heartbeat): (String, i64) = (r[0].get(0), r[0].get(1)); + Some(SqliteWorker::new(pod_name, last_heartbeat as u64)) + } + _ => unreachable!(), + }; + + if worker.is_some() { + // There is already an assigned worker. + // We can release the lock and return the worker. + tx.commit().await?; + return Ok(worker.unwrap()); + } + + // Pick a random live worker. + let stmt = tx + .prepare( + "SELECT id, pod_name, last_heartbeat + FROM sqlite_workers + WHERE last_heartbeat > $1 ORDER BY RANDOM() LIMIT 1", + ) + .await?; + let r = tx.query(&stmt, &[&((utils::now() - ttl) as i64)]).await?; + + match r.len() { + 0 => Err(anyhow!("No live workers found"))?, + 1 => { + let (sqlite_worker_row_id, pod_name, last_heartbeat): (i64, String, i64) = + (r[0].get(0), r[0].get(1), r[0].get(2)); + + // Update the database row to assign the worker. + let stmt = tx + .prepare( + "UPDATE databases SET sqlite_worker = $1 \ + WHERE data_source = $2 AND database_id = $3", + ) + .await?; + tx.execute( + &stmt, + &[ + &sqlite_worker_row_id, + &data_source_row_id, + &database_id.to_string(), + ], + ) + .await?; + + // Release the lock. + tx.commit().await?; + + Ok(SqliteWorker::new(pod_name, last_heartbeat as u64)) + } + _ => unreachable!(), + } + } + async fn upsert_database_table( &self, project: &Project, @@ -2311,253 +2433,6 @@ impl Store for PostgresStore { Ok((tables, total)) } - async fn load_database_row( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - row_id: &str, - ) -> Result> { - let project_id = project.project_id(); - let data_source_id = data_source_id.to_string(); - let database_id = database_id.to_string(); - let table_id = table_id.to_string(); - let row_id = row_id.to_string(); - - let pool = self.pool.clone(); - let c = pool.get().await?; - - let stmt = c - .prepare( - "SELECT created, row_id, content FROM databases_rows \ - WHERE database_table IN ( - SELECT id FROM databases_tables WHERE database IN ( - SELECT id FROM databases WHERE data_source IN ( - SELECT id FROM data_sources WHERE project = $1 AND data_source_id = $2 - ) AND database_id = $3 - ) AND table_id = $4 - ) \ - AND row_id = $5 LIMIT 1", - ) - .await?; - let r = c - .query( - &stmt, - &[ - &project_id, - &data_source_id, - &database_id, - &table_id, - &row_id, - ], - ) - .await?; - - let d: Option<(i64, String, String)> = match r.len() { - 0 => None, - 1 => Some((r[0].get(0), r[0].get(1), r[0].get(2))), - _ => unreachable!(), - }; - - match d { - None => Ok(None), - Some((_, row_id, data)) => Ok(Some(DatabaseRow::new(row_id, Value::from_str(&data)?))), - } - } - - async fn list_database_rows( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - limit_offset: Option<(usize, usize)>, - ) -> Result<(Vec, usize)> { - let project_id = project.project_id(); - let data_source_id = data_source_id.to_string(); - let database_id = database_id.to_string(); - - let pool = self.pool.clone(); - let c = pool.get().await?; - - let r = c - .query( - "SELECT id FROM data_sources WHERE project = $1 AND data_source_id = $2 LIMIT 1", - &[&project_id, &data_source_id], - ) - .await?; - - let data_source_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown DataSource: {}", data_source_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - let r = c - .query( - "SELECT id FROM databases WHERE data_source = $1 AND database_id = $2 LIMIT 1", - &[&data_source_row_id, &database_id], - ) - .await?; - - let database_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown Database: {}", database_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - let r = c - .query( - "SELECT id FROM databases_tables WHERE database = $1 AND table_id = $2 LIMIT 1", - &[&database_row_id, &table_id], - ) - .await?; - - let table_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown Table: {}", table_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - let mut params: Vec<&(dyn ToSql + Sync)> = vec![&table_row_id]; - let mut query = "SELECT created, row_id, content FROM databases_rows \ - WHERE database_table = $1 ORDER BY created DESC" - .to_string(); - - let limit_i64: i64; - let offset_i64: i64; - if let Some((limit, offset)) = limit_offset { - query.push_str(" LIMIT $2 OFFSET $3"); - limit_i64 = limit as i64; - offset_i64 = offset as i64; - params.push(&limit_i64); - params.push(&(offset_i64)); - } - - let rows = c.query(&query, ¶ms).await?; - - let rows: Vec = rows - .iter() - .map(|row| { - let row_id: String = row.get(1); - let data: String = row.get(2); - let content: Value = serde_json::from_str(&data)?; - Ok(DatabaseRow::new(row_id, content)) - }) - .collect::>>()?; - - let total = match limit_offset { - None => rows.len(), - Some(_) => { - let t: i64 = c - .query_one( - "SELECT COUNT(*) FROM databases_rows WHERE database_table = $1", - &[&table_row_id], - ) - .await? - .get(0); - t as usize - } - }; - - Ok((rows, total)) - } - - async fn batch_upsert_database_rows( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - rows: &Vec, - truncate: bool, - ) -> Result<()> { - let project_id = project.project_id(); - let data_source_id = data_source_id.to_string(); - let database_id = database_id.to_string(); - let table_id = table_id.to_string(); - - let pool = self.pool.clone(); - let mut c = pool.get().await?; - - // Get the data source row id. - let stmt = c - .prepare( - "SELECT id FROM data_sources WHERE project = $1 AND data_source_id = $2 LIMIT 1", - ) - .await?; - let r = c.query(&stmt, &[&project_id, &data_source_id]).await?; - let data_source_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown DataSource: {}", data_source_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - // Get the database row id. - let stmt = c - .prepare("SELECT id FROM databases WHERE data_source = $1 AND database_id = $2 LIMIT 1") - .await?; - let r = c.query(&stmt, &[&data_source_row_id, &database_id]).await?; - let database_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown Database: {}", database_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - // Get the table row id. - let stmt = c - .prepare( - "SELECT id FROM databases_tables WHERE database = $1 AND table_id = $2 LIMIT 1", - ) - .await?; - let r = c.query(&stmt, &[&database_row_id, &table_id]).await?; - let table_row_id: i64 = match r.len() { - 0 => Err(anyhow!("Unknown Table: {}", table_id))?, - 1 => r[0].get(0), - _ => unreachable!(), - }; - - // Start transaction. - let c = c.transaction().await?; - - // Truncate table if required. - if truncate { - let stmt = c - .prepare("DELETE FROM databases_rows WHERE database_table = $1") - .await?; - c.execute(&stmt, &[&table_row_id]).await?; - } - - // Prepare insertion/updation statement. - let stmt = c - .prepare( - "INSERT INTO databases_rows \ - (id, database_table, created, row_id, content) \ - VALUES (DEFAULT, $1, $2, $3, $4) \ - ON CONFLICT (row_id, database_table) DO UPDATE \ - SET content = EXCLUDED.content", - ) - .await?; - - for row in rows { - c.execute( - &stmt, - &[ - &table_row_id, - &(utils::now() as i64), - &row.row_id().to_string(), - &row.content().to_string(), - ], - ) - .await?; - } - - c.commit().await?; - - Ok(()) - } - async fn delete_database( &self, project: &Project, @@ -2928,6 +2803,20 @@ impl Store for PostgresStore { let pool = self.pool.clone(); let c = pool.get().await?; + // Remove the worker from the databases. + let stmt = c + .prepare( + "UPDATE databases SET sqlite_worker = NULL \ + WHERE sqlite_worker IN ( + SELECT id + FROM sqlite_workers + WHERE pod_name = $1 + )", + ) + .await?; + c.execute(&stmt, &[&pod_name.to_string()]).await?; + + // Delete the worker. let stmt = c .prepare("DELETE FROM sqlite_workers WHERE pod_name = $1") .await?; @@ -2941,6 +2830,20 @@ impl Store for PostgresStore { let pool = self.pool.clone(); let c = pool.get().await?; + // Remove the dead workers from the databases. + let stmt = c + .prepare( + "UPDATE databases SET sqlite_worker = NULL \ + WHERE sqlite_worker IN ( + SELECT id + FROM sqlite_workers + WHERE last_heartbeat < $1 + )", + ) + .await?; + c.execute(&stmt, &[&(utils::now() as i64 - ttl as i64)]) + .await?; + let stmt = c .prepare("DELETE FROM sqlite_workers WHERE last_heartbeat < $1") .await?; diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs index bf0b215ff6a7..b9da5f00d3a0 100644 --- a/core/src/stores/store.rs +++ b/core/src/stores/store.rs @@ -2,7 +2,7 @@ use crate::blocks::block::BlockType; use crate::data_sources::data_source::{ DataSource, DataSourceConfig, Document, DocumentVersion, SearchFilter, }; -use crate::databases::database::{Database, DatabaseRow, DatabaseTable}; +use crate::databases::database::{Database, DatabaseTable}; use crate::databases::table_schema::TableSchema; use crate::dataset::Dataset; use crate::http::request::{HttpRequest, HttpResponse}; @@ -10,7 +10,7 @@ use crate::project::Project; use crate::providers::embedder::{EmbedderRequest, EmbedderVector}; use crate::providers::llm::{LLMChatGeneration, LLMChatRequest, LLMGeneration, LLMRequest}; use crate::run::{Run, RunStatus, RunType}; -use crate::sqlite_workers::sqlite_workers::SqliteWorker; +use crate::sqlite_workers::client::SqliteWorker; use anyhow::Result; use async_trait::async_trait; @@ -178,6 +178,13 @@ pub trait Store { data_source_id: &str, limit_offset: Option<(usize, usize)>, ) -> Result>; + async fn assign_live_sqlite_worker_to_database( + &self, + project: &Project, + data_source_id: &str, + database_id: &str, + ttl: u64, + ) -> Result; async fn upsert_database_table( &self, project: &Project, @@ -209,31 +216,6 @@ pub trait Store { database_id: &str, limit_offset: Option<(usize, usize)>, ) -> Result<(Vec, usize)>; - async fn batch_upsert_database_rows( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - rows: &Vec, - truncate: bool, - ) -> Result<()>; - async fn load_database_row( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - row_id: &str, - ) -> Result>; - async fn list_database_rows( - &self, - project: &Project, - data_source_id: &str, - database_id: &str, - table_id: &str, - limit_offset: Option<(usize, usize)>, - ) -> Result<(Vec, usize)>; async fn delete_database( &self, project: &Project, @@ -308,7 +290,7 @@ impl Clone for Box { } } -pub const POSTGRES_TABLES: [&'static str; 15] = [ +pub const POSTGRES_TABLES: [&'static str; 14] = [ "-- projects CREATE TABLE IF NOT EXISTS projects ( id BIGSERIAL PRIMARY KEY @@ -413,14 +395,23 @@ pub const POSTGRES_TABLES: [&'static str; 15] = [ status TEXT NOT NULL, FOREIGN KEY(data_source) REFERENCES data_sources(id) );", - "-- database - CREATE TABLE IF NOT EXISTS databases ( + "-- SQLite workers + CREATE TABLE IF NOT EXISTS sqlite_workers ( id BIGSERIAL PRIMARY KEY, created BIGINT NOT NULL, - data_source BIGINT NOT NULL, - database_id TEXT NOT NULL, -- unique within data source. Used as the external id. - name TEXT NOT NULL, -- unique within data source - FOREIGN KEY(data_source) REFERENCES data_sources(id) + pod_name TEXT NOT NULL, + last_heartbeat BIGINT NOT NULL + );", + "-- database + CREATE TABLE IF NOT EXISTS databases ( + id BIGSERIAL PRIMARY KEY, + created BIGINT NOT NULL, + data_source BIGINT NOT NULL, + database_id TEXT NOT NULL, -- unique within data source. Used as the external id. + name TEXT NOT NULL, -- unique within data source + sqlite_worker BIGINT, + FOREIGN KEY(data_source) REFERENCES data_sources(id), + FOREIGN KEY(sqlite_worker) REFERENCES sqlite_workers(id) );", "-- databases tables CREATE TABLE IF NOT EXISTS databases_tables ( @@ -433,25 +424,9 @@ pub const POSTGRES_TABLES: [&'static str; 15] = [ schema TEXT, -- json, kept up-to-date automatically with the last insert FOREIGN KEY(database) REFERENCES databases(id) );", - "-- databases row - CREATE TABLE IF NOT EXISTS databases_rows ( - id BIGSERIAL PRIMARY KEY, - created BIGINT NOT NULL, - database_table BIGINT NOT NULL, - content TEXT NOT NULL, -- json - row_id TEXT NOT NULL, -- unique within table - FOREIGN KEY(database_table) REFERENCES databases_tables(id) - );", - "-- SQLite workers - CREATE TABLE IF NOT EXISTS sqlite_workers ( - id BIGSERIAL PRIMARY KEY, - created BIGINT NOT NULL, - pod_name TEXT NOT NULL, - last_heartbeat BIGINT NOT NULL - );", ]; -pub const SQL_INDEXES: [&'static str; 24] = [ +pub const SQL_INDEXES: [&'static str; 23] = [ "CREATE INDEX IF NOT EXISTS idx_specifications_project_created ON specifications (project, created);", "CREATE INDEX IF NOT EXISTS @@ -502,8 +477,6 @@ pub const SQL_INDEXES: [&'static str; 24] = [ idx_databases_tables_table_id_database ON databases_tables (table_id, database);", "CREATE UNIQUE INDEX IF NOT EXISTS idx_databases_tables_database_table_name ON databases_tables (database, name);", - "CREATE UNIQUE INDEX IF NOT EXISTS - idx_databases_rows_row_id_database_table ON databases_rows (row_id, database_table);", "CREATE UNIQUE INDEX IF NOT EXISTS idx_sqlite_workers_pod_name ON sqlite_workers (pod_name);", ]; diff --git a/k8s/deployments/core-sqlite-worker-deployment.yaml b/k8s/deployments/core-sqlite-worker-deployment.yaml index 5caac48689df..7a598082b43c 100644 --- a/k8s/deployments/core-sqlite-worker-deployment.yaml +++ b/k8s/deployments/core-sqlite-worker-deployment.yaml @@ -34,9 +34,8 @@ spec: envFrom: - configMapRef: name: core-sqlite-worker-config - # TODO: Uncomment this when we have secrets - # - secretRef: - # name: core-secrets + - secretRef: + name: core-sqlite-worker-secrets env: - name: DD_AGENT_HOST valueFrom: