diff --git a/core/Cargo.toml b/core/Cargo.toml
index 7597ddb565a3..4c0c1e4b905e 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -27,6 +27,10 @@ path = "bin/init_db.rs"
 name = "elasticsearch_create_index"
 path = "bin/elasticsearch/create_index.rs"
 
+[[bin]]
+name = "elasticsearch_backfill_index"
+path = "bin/elasticsearch/backfill_index.rs"
+
 [[bin]]
 name = "qdrant_create_collection"
 path = "bin/qdrant/create_collection.rs"
diff --git a/core/bin/elasticsearch/backfill_index.rs b/core/bin/elasticsearch/backfill_index.rs
new file mode 100644
index 000000000000..b3d2d75f9189
--- /dev/null
+++ b/core/bin/elasticsearch/backfill_index.rs
@@ -0,0 +1,137 @@
+use clap::Parser;
+use dust::{
+    data_sources::node::Node,
+    search_stores::search_store::ElasticsearchSearchStore,
+    stores::{postgres::PostgresStore, store::Store},
+    utils::{self},
+};
+use elasticsearch::{http::request::JsonBody, indices::IndicesExistsParts, BulkParts};
+use http::StatusCode;
+use serde_json::json;
+
+#[derive(Parser, Debug)]
+#[command(author, version, about, long_about = None)]
+struct Args {
+    #[arg(long, help = "The version of the index")]
+    index_version: u32,
+
+    #[arg(long, help = "Skip confirmation")]
+    skip_confirmation: bool,
+
+    #[arg(long, help = "The cursor to start from", default_value = "0")]
+    start_cursor: i64,
+
+    #[arg(long, help = "The batch size", default_value = "100")]
+    batch_size: i64,
+}
+
+/*
+ * Backfills the nodes index in Elasticsearch for core using the postgres table `data_sources_nodes`
+ *
+ * Usage:
+ * cargo run --bin elasticsearch_backfill_index -- --index-version <version> [--skip-confirmation] [--start-cursor <cursor>] [--batch-size <batch_size>]
+ *
+ */
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // parse args and env vars
+    let args = Args::parse();
+    let index_name = "data_sources_nodes";
+    let index_version = args.index_version;
+    let batch_size = args.batch_size;
+    let start_cursor = args.start_cursor;
+
+    let url = std::env::var("ELASTICSEARCH_URL").expect("ELASTICSEARCH_URL must be set");
+    let username =
+        std::env::var("ELASTICSEARCH_USERNAME").expect("ELASTICSEARCH_USERNAME must be set");
+    let password =
+        std::env::var("ELASTICSEARCH_PASSWORD").expect("ELASTICSEARCH_PASSWORD must be set");
+
+    let region = std::env::var("DUST_REGION").expect("DUST_REGION must be set");
+
+    // create ES client
+    let search_store = ElasticsearchSearchStore::new(&url, &username, &password).await?;
+
+    let index_fullname = format!("core.{}_{}", index_name, index_version);
+
+    // err if index does not exist
+    let response = search_store
+        .client
+        .indices()
+        .exists(IndicesExistsParts::Index(&[index_fullname.as_str()]))
+        .send()
+        .await?;
+
+    if response.status_code() != StatusCode::OK {
+        return Err(anyhow::anyhow!("Index does not exist").into());
+    }
+
+    if !args.skip_confirmation {
+        println!(
+            "Are you sure you want to backfill the index {} in region {}? (y/N)",
+            index_fullname, region
+        );
+        let mut input = String::new();
+        std::io::stdin().read_line(&mut input).unwrap();
+        if input.trim() != "y" {
+            return Err(anyhow::anyhow!("Aborted").into());
+        }
+    }
+
+    let db_uri = std::env::var("CORE_DATABASE_READ_REPLICA_URI")
+        .expect("CORE_DATABASE_READ_REPLICA_URI must be set");
+    let store = PostgresStore::new(&db_uri).await?;
+    // loop on all nodes in postgres using id as cursor, stopping when timestamp
+    // is greater than now
+    let mut next_cursor = start_cursor;
+    let now = utils::now();
+    loop {
+        println!(
+            "Processing {} nodes, starting at id {}",
+            batch_size, next_cursor
+        );
+        let (nodes, cursor) =
+            get_node_batch(next_cursor, batch_size, Box::new(store.clone())).await?;
+        if nodes.is_empty() || nodes.first().unwrap().timestamp > now {
+            break;
+        }
+        next_cursor = cursor;
+
+        // bulk body: one action line and one document line per node
+        let nodes_body: Vec<JsonBody<serde_json::Value>> = nodes
+            .into_iter()
+            .flat_map(|node| {
+                [
+                    json!({"index": {"_id": node.unique_id()}}).into(),
+                    json!(node).into(),
+                ]
+            })
+            .collect();
+        search_store
+            .client
+            .bulk(BulkParts::Index(index_fullname.as_str()))
+            .body(nodes_body)
+            .send()
+            .await?;
+    }
+
+    Ok(())
+}
+
+async fn get_node_batch(
+    next_cursor: i64,
+    batch_size: i64,
+    store: Box<dyn Store + Sync + Send>,
+) -> Result<(Vec<Node>, i64), Box<dyn std::error::Error>> {
+    let nodes = store
+        .list_data_source_nodes(next_cursor, batch_size)
+        .await?;
+    let last_node = nodes.last().cloned();
+    match last_node {
+        Some((_, last_row_id, _)) => Ok((
+            nodes.into_iter().map(|(node, _, _)| node).collect(),
+            last_row_id,
+        )),
+        None => Ok((vec![], 0)),
+    }
+}
diff --git a/core/bin/elasticsearch/create_index.rs b/core/bin/elasticsearch/create_index.rs
index f9d672d78544..fc64d96dabfd 100644
--- a/core/bin/elasticsearch/create_index.rs
+++ b/core/bin/elasticsearch/create_index.rs
@@ -1,26 +1,10 @@
 use std::collections::HashMap;
 
-use clap::{Parser, ValueEnum};
+use clap::Parser;
 use dust::search_stores::search_store::ElasticsearchSearchStore;
 use elasticsearch::indices::{IndicesCreateParts, IndicesDeleteAliasParts, IndicesExistsParts};
 use http::StatusCode;
 
-#[derive(Parser, Debug, Clone, ValueEnum)]
-enum Region {
-    Local,
-    #[clap(name = "us-central-1")]
-    UsCentral1,
-}
-
-impl std::fmt::Display for Region {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match self {
-            Region::Local => write!(f, "local"),
-            Region::UsCentral1 => write!(f, "us-central-1"),
-        }
-    }
-}
-
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
diff --git a/core/src/stores/postgres.rs b/core/src/stores/postgres.rs
index 589158efee0d..2d8475681fd2 100644
--- a/core/src/stores/postgres.rs
+++ b/core/src/stores/postgres.rs
@@ -3525,6 +3525,64 @@ impl Store for PostgresStore {
         }
     }
 
+    async fn list_data_source_nodes(
+        &self,
+        id_cursor: i64,
+        batch_size: i64,
+    ) -> Result<Vec<(Node, i64, i64)>> {
+        let pool = self.pool.clone();
+        let c = pool.get().await?;
+
+        let stmt = c
+            .prepare(
+                "SELECT dsn.timestamp, dsn.title, dsn.mime_type, dsn.parents, dsn.node_id, dsn.document, dsn.\"table\", dsn.folder, ds.data_source_id, ds.internal_id, dsn.id \
+                FROM data_sources_nodes dsn JOIN data_sources ds ON dsn.data_source = ds.id \
+                WHERE dsn.id > $1 ORDER BY dsn.id ASC LIMIT $2",
+            )
+            .await?;
+        let rows = c.query(&stmt, &[&id_cursor, &batch_size]).await?;
+
+        let nodes: Vec<(Node, i64, i64)> = rows
+            .iter()
+            .map(|row| {
+                let timestamp: i64 = row.get::<_, i64>(0);
+                let title: String = row.get::<_, String>(1);
+                let mime_type: String = row.get::<_, String>(2);
+                let parents: Vec<String> = row.get::<_, Vec<String>>(3);
+                let node_id: String = row.get::<_, String>(4);
+                let document_row_id = row.get::<_, Option<i64>>(5);
+                let table_row_id = row.get::<_, Option<i64>>(6);
+                let folder_row_id = row.get::<_, Option<i64>>(7);
+                let data_source_id: String = row.get::<_, String>(8);
+                let data_source_internal_id: String = row.get::<_, String>(9);
+                let (node_type, element_row_id) =
+                    match (document_row_id, table_row_id, folder_row_id) {
+                        (Some(id), None, None) => (NodeType::Document, id),
+                        (None, Some(id), None) => (NodeType::Table, id),
+                        (None, None, Some(id)) => (NodeType::Folder, id),
+                        _ => unreachable!(),
+                    };
+                let row_id = row.get::<_, i64>(10);
+                (
+                    Node::new(
+                        &data_source_id,
+                        &data_source_internal_id,
+                        &node_id,
+                        node_type,
+                        timestamp as u64,
+                        &title,
+                        &mime_type,
+                        parents.get(1).cloned(),
+                        parents,
+                    ),
+                    row_id,
+                    element_row_id,
+                )
+            })
+            .collect::<Vec<_>>();
+        Ok(nodes)
+    }
+
     async fn llm_cache_get(
         &self,
         project: &Project,
diff --git a/core/src/stores/store.rs b/core/src/stores/store.rs
index 109baa7aa01e..a187e6acad63 100644
--- a/core/src/stores/store.rs
+++ b/core/src/stores/store.rs
@@ -347,6 +347,13 @@ pub trait Store {
         data_source_id: &str,
         node_id: &str,
     ) -> Result<Option<Node>>;
+    // returns a list of (node, row_id, element_row_id)
+    async fn list_data_source_nodes(
+        &self,
+        id_cursor: i64,
+        batch_size: i64,
+    ) -> Result<Vec<(Node, i64, i64)>>;
+
     // LLM Cache
     async fn llm_cache_get(
         &self,
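
For reference, a possible invocation of the new backfill binary, assembled from the flags defined in Args and the environment variables the binary reads; every value below is a placeholder, not something prescribed by this change:

    ELASTICSEARCH_URL=<url> ELASTICSEARCH_USERNAME=<user> ELASTICSEARCH_PASSWORD=<password> \
    DUST_REGION=<region> CORE_DATABASE_READ_REPLICA_URI=<read-replica-uri> \
    cargo run --bin elasticsearch_backfill_index -- --index-version 2 --start-cursor 0 --batch-size 500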