Skip to content

Commit

Permalink
[core] feature(vault): Filter tables in APIs (#6725)
Browse files Browse the repository at this point in the history
  • Loading branch information
tdraier authored and albandum committed Aug 28, 2024
1 parent 6fdcd10 commit 0285622
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 73 deletions.
230 changes: 160 additions & 70 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use axum::{
routing::{delete, get, patch, post},
Router,
};

use dust::{
api_keys::validate_api_key,
app,
Expand All @@ -25,7 +26,7 @@ use dust::{
project,
providers::provider::{provider, ProviderID},
run,
search_filter::SearchFilter,
search_filter::{Filterable, SearchFilter},
secondary_api::forward_middleware,
sqlite_workers::client::{self, HEARTBEAT_INTERVAL_MS},
stores::{postgres, store},
Expand Down Expand Up @@ -1990,10 +1991,34 @@ async fn tables_upsert(
}
}

/// Retrieve table from a data source.
#[derive(serde::Deserialize)]
struct TableRetrieveQuery {
view_filter: Option<String>, // Parsed as JSON.
}

async fn tables_retrieve(
Path((project_id, data_source_id, table_id)): Path<(i64, String, String)>,
State(state): State<Arc<APIState>>,
Query(query): Query<TableRetrieveQuery>,
) -> (StatusCode, Json<APIResponse>) {
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

let project = project::Project::new_from_id(project_id);

match state
Expand All @@ -2007,7 +2032,7 @@ async fn tables_retrieve(
"Failed to retrieve table",
Some(e),
),
Ok(table) => match table {
Ok(table) => match table.filter(|table| table.match_filter(&view_filter)) {
None => error_response(
StatusCode::NOT_FOUND,
"table_not_found",
Expand All @@ -2030,8 +2055,26 @@ async fn tables_retrieve(
async fn tables_list(
Path((project_id, data_source_id)): Path<(i64, String)>,
State(state): State<Arc<APIState>>,
Query(query): Query<TableRetrieveQuery>,
) -> (StatusCode, Json<APIResponse>) {
let project = project::Project::new_from_id(project_id);
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

match state
.store
.list_tables(&project, &data_source_id, None)
Expand All @@ -2048,7 +2091,7 @@ async fn tables_list(
Json(APIResponse {
error: None,
response: Some(json!({
"tables": tables,
"tables": tables.into_iter().filter(|table| table.match_filter(&view_filter)).collect::<Vec<Table>>(),
})),
}),
),
Expand Down Expand Up @@ -2218,8 +2261,25 @@ async fn tables_rows_upsert(
async fn tables_rows_retrieve(
Path((project_id, data_source_id, table_id, row_id)): Path<(i64, String, String, String)>,
State(state): State<Arc<APIState>>,
Query(query): Query<TableRetrieveQuery>,
) -> (StatusCode, Json<APIResponse>) {
let project = project::Project::new_from_id(project_id);
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

match state
.store
Expand All @@ -2234,39 +2294,41 @@ async fn tables_rows_retrieve(
Some(e),
)
}
Ok(None) => {
return error_response(
StatusCode::NOT_FOUND,
"table_not_found",
&format!("No table found for id `{}`", table_id),
None,
)
}
Ok(Some(table)) => match table
.retrieve_row(state.databases_store.clone(), &row_id)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to load row",
Some(e),
),
Ok(None) => error_response(
StatusCode::NOT_FOUND,
"table_row_not_found",
&format!("No table row found for id `{}`", row_id),
None,
),
Ok(Some(row)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"row": row,
})),
}),
),
Ok(table) => match table.filter(|table| table.match_filter(&view_filter)) {
None => {
return error_response(
StatusCode::NOT_FOUND,
"table_not_found",
&format!("No table found for id `{}`", table_id),
None,
)
}
Some(table) => match table
.retrieve_row(state.databases_store.clone(), &row_id)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to load row",
Some(e),
),
Ok(None) => error_response(
StatusCode::NOT_FOUND,
"table_row_not_found",
&format!("No table row found for id `{}`", row_id),
None,
),
Ok(Some(row)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"row": row,
})),
}),
),
},
},
}
}
Expand Down Expand Up @@ -2329,6 +2391,7 @@ async fn tables_rows_delete(
struct DatabasesRowsListQuery {
offset: usize,
limit: usize,
view_filter: Option<String>,
}

async fn tables_rows_list(
Expand All @@ -2337,6 +2400,22 @@ async fn tables_rows_list(
Query(query): Query<DatabasesRowsListQuery>,
) -> (StatusCode, Json<APIResponse>) {
let project = project::Project::new_from_id(project_id);
let view_filter: Option<SearchFilter> = match query
.view_filter
.as_ref()
.and_then(|f| Some(serde_json::from_str(f)))
{
Some(Ok(f)) => Some(f),
None => None,
Some(Err(e)) => {
return error_response(
StatusCode::BAD_REQUEST,
"invalid_view_filter",
"Failed to parse view_filter query parameter",
Some(e.into()),
)
}
};

match state
.store
Expand All @@ -2351,39 +2430,41 @@ async fn tables_rows_list(
Some(e),
)
}
Ok(None) => {
return error_response(
StatusCode::NOT_FOUND,
"table_not_found",
&format!("No table found for id `{}`", table_id),
None,
)
}
Ok(Some(table)) => match table
.list_rows(
state.databases_store.clone(),
Some((query.limit, query.offset)),
)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to list rows",
Some(e),
),
Ok((rows, total)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"offset": query.offset,
"limit": query.limit,
"total": total,
"rows": rows,
})),
}),
),
Ok(table) => match table.filter(|table| table.match_filter(&view_filter)) {
None => {
return error_response(
StatusCode::NOT_FOUND,
"table_not_found",
&format!("No table found for id `{}`", table_id),
None,
)
}
Some(table) => match table
.list_rows(
state.databases_store.clone(),
Some((query.limit, query.offset)),
)
.await
{
Err(e) => error_response(
StatusCode::INTERNAL_SERVER_ERROR,
"internal_server_error",
"Failed to list rows",
Some(e),
),
Ok((rows, total)) => (
StatusCode::OK,
Json(APIResponse {
error: None,
response: Some(json!({
"offset": query.offset,
"limit": query.limit,
"total": total,
"rows": rows,
})),
}),
),
},
},
}
}
Expand All @@ -2392,6 +2473,7 @@ async fn tables_rows_list(
struct DatabaseQueryRunPayload {
query: String,
tables: Vec<(i64, String, String)>,
view_filter: Option<SearchFilter>,
}

async fn databases_query_run(
Expand All @@ -2418,7 +2500,15 @@ async fn databases_query_run(
),
Ok(tables) => {
// Check that all tables exist.
match tables.into_iter().collect::<Option<Vec<Table>>>() {
match tables
.into_iter()
.filter(|table| {
table
.as_ref()
.map_or(true, |t| t.match_filter(&payload.view_filter))
})
.collect::<Option<Vec<Table>>>()
{
None => {
return error_response(
StatusCode::NOT_FOUND,
Expand Down
Loading

0 comments on commit 0285622

Please sign in to comment.