Skip to content

Commit

Permalink
[Groups] core: data source block view filter (#6525)
Browse files Browse the repository at this point in the history
* front: CoreAPISearchFilter initial plumbing

* core implementation of X-Dust-Group-Ids and retrieval of view_filter from registry
  • Loading branch information
spolu authored Jul 26, 2024
1 parent 955faf0 commit dc9948c
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 56 deletions.
18 changes: 18 additions & 0 deletions core/bin/dust_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,15 @@ async fn runs_create(
},
None => (),
};
match headers.get("X-Dust-Group-Ids") {
Some(v) => match v.to_str() {
Ok(v) => {
credentials.insert("DUST_GROUP_IDS".to_string(), v.to_string());
}
_ => (),
},
None => (),
};

match run_helper(project_id, payload.clone(), state.clone()).await {
Ok(app) => {
Expand Down Expand Up @@ -834,6 +843,15 @@ async fn runs_create_stream(
},
None => (),
};
match headers.get("X-Dust-Group-Ids") {
Some(v) => match v.to_str() {
Ok(v) => {
credentials.insert("DUST_GROUP_IDS".to_string(), v.to_string());
}
_ => (),
},
None => (),
};

// create unbounded channel to pass as stream to Sse::new
let (tx, mut rx) = unbounded_channel::<Value>();
Expand Down
19 changes: 7 additions & 12 deletions core/src/blocks/data_source.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::blocks::block::{
parse_pair, replace_variables_in_string, Block, BlockResult, BlockType, Env,
};
use crate::blocks::helpers::get_data_source_project;
use crate::blocks::helpers::get_data_source_project_and_view_filter;
use crate::data_sources::data_source::{Document, SearchFilter};
use crate::deno::js_executor::JSExecutor;
use crate::Rule;
Expand Down Expand Up @@ -75,18 +75,14 @@ impl DataSource {
async fn search_data_source(
&self,
env: &Env,
workspace_id: Option<String>,
workspace_id: String,
data_source_id: String,
top_k: usize,
filter: Option<SearchFilter>,
target_document_tokens: Option<usize>,
) -> Result<Vec<Document>> {
let data_source_project = match workspace_id {
Some(workspace_id) => {
get_data_source_project(&workspace_id, &data_source_id, env).await?
}
None => env.project.clone(),
};
let (data_source_project, view_filter) =
get_data_source_project_and_view_filter(&workspace_id, &data_source_id, env).await?;

let ds = match env
.store
Expand All @@ -110,8 +106,7 @@ impl DataSource {
Some(filter) => Some(filter.postprocess_for_data_source(&data_source_id)),
None => None,
},
// TODO(spolu): add in subsequent PR (data_source block view_filter support).
None,
view_filter,
self.full_text,
target_document_tokens,
)
Expand Down Expand Up @@ -200,8 +195,8 @@ impl Block for DataSource {
.iter()
.map(|v| {
let workspace_id = match v.get("workspace_id") {
Some(Value::String(p)) => Some(p.clone()),
_ => None,
Some(Value::String(p)) => p.clone(),
_ => Err(anyhow!(err_msg.clone()))?,
};
let data_source_id = match v.get("data_source_id") {
Some(Value::String(i)) => i.clone(),
Expand Down
12 changes: 7 additions & 5 deletions core/src/blocks/database_schema.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::helpers::get_data_source_project;
use super::helpers::get_data_source_project_and_view_filter;
use crate::blocks::block::{Block, BlockResult, BlockType, Env};
use crate::databases::database::{get_unique_table_names_for_database, Table};
use crate::Rule;
Expand Down Expand Up @@ -126,18 +126,20 @@ pub async fn load_tables_from_identifiers(
.collect::<Vec<_>>();

// Get a vec of the corresponding project ids for each (workspace_id, data_source_id) pair.
let project_ids = try_join_all(
let project_ids_view_filters = try_join_all(
data_source_identifiers
.iter()
.map(|(w, d)| get_data_source_project(w, d, env)),
.map(|(w, d)| get_data_source_project_and_view_filter(w, d, env)),
)
.await?;

// TODO(GROUPS_INFRA): enforce view_filter as returned above.

// Create a hashmap of (workspace_id, data_source_id) -> project_id.
let project_by_data_source = data_source_identifiers
.iter()
.zip(project_ids.iter())
.map(|((w, d), p)| ((*w, *d), p.clone()))
.zip(project_ids_view_filters.iter())
.map(|((w, d), p)| ((*w, *d), p.0.clone()))
.collect::<std::collections::HashMap<_, _>>();

let store = env.store.clone();
Expand Down
38 changes: 26 additions & 12 deletions core/src/blocks/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,40 @@
use super::block::Env;
use crate::project::Project;
use crate::{data_sources::data_source::SearchFilter, project::Project};
use anyhow::{anyhow, Result};
use hyper::body::Buf;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::io::prelude::*;
use url::Url;
use urlencoding::encode;

pub async fn get_data_source_project(
/// JSON body returned by the front registry when looking up a data source.
///
/// The registry response is deserialized into this struct by
/// `get_data_source_project_and_view_filter`.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct FrontRegistryPayload {
    /// Data source identifier as returned by the registry.
    data_source_id: String,
    /// Numeric project id; turned into a `Project` via `Project::new_from_id`.
    project_id: i64,
    /// Optional filter attached to the lookup result; `None` when the registry
    /// returns `null` for this field.
    view_filter: Option<SearchFilter>,
}

pub async fn get_data_source_project_and_view_filter(
workspace_id: &String,
data_source_id: &String,
env: &Env,
) -> Result<Project> {
) -> Result<(Project, Option<SearchFilter>)> {
let dust_workspace_id = match env.credentials.get("DUST_WORKSPACE_ID") {
None => Err(anyhow!(
"DUST_WORKSPACE_ID credentials missing, but `workspace_id` \
is set in `data_source` block config"
))?,
Some(v) => v.clone(),
};
let dust_group_ids = match env.credentials.get("DUST_GROUP_IDS") {
Some(v) => v.clone(),
// If not set, we default to the empty string, which the front registry will treat as the
// workspace global group.
None => "".to_string(),
};

let registry_secret = match std::env::var("DUST_REGISTRY_SECRET") {
Ok(key) => key,
Err(_) => Err(anyhow!(
Expand All @@ -46,6 +61,7 @@ pub async fn get_data_source_project(
format!("Bearer {}", registry_secret.as_str()),
)
.header("X-Dust-Workspace-Id", dust_workspace_id)
.header("X-Dust-Group-Ids", dust_group_ids)
.send()
.await?;

Expand All @@ -65,16 +81,14 @@ pub async fn get_data_source_project(

let response_body = String::from_utf8_lossy(&b).into_owned();

let body = match serde_json::from_str::<serde_json::Value>(&response_body) {
Ok(body) => body,
// parse body into FrontRegistryPayload
let payload: FrontRegistryPayload = match serde_json::from_str(&response_body) {
Ok(payload) => payload,
Err(_) => Err(anyhow!("Failed to parse registry response"))?,
};

match body.get("project_id") {
Some(Value::Number(p)) => match p.as_i64() {
Some(p) => Ok(Project::new_from_id(p)),
None => Err(anyhow!("Failed to parse registry response")),
},
_ => Err(anyhow!("Failed to parse registry response")),
}
Ok((
Project::new_from_id(payload.project_id),
payload.view_filter,
))
}
16 changes: 15 additions & 1 deletion front/pages/api/registry/[type]/lookup.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { CoreAPISearchFilter } from "@dust-tt/types";
import type { NextApiRequest, NextApiResponse } from "next";

import { DataSource } from "@app/lib/models/data_source";
Expand All @@ -9,6 +10,7 @@ const { DUST_REGISTRY_SECRET } = process.env;
type LookupDataSourceResponseBody = {
project_id: number;
data_source_id: string;
view_filter: CoreAPISearchFilter | null;
};

/**
Expand Down Expand Up @@ -56,6 +58,10 @@ async function handler(
return;
}

// TODO(GROUPS_INFRA): Add x-dust-group-ids header retrieval
// - If not set, default to the global workspace group
// - Enforce checks for access to data sources and data source views below

const dustWorkspaceId = req.headers["x-dust-workspace-id"] as string;

switch (req.method) {
Expand Down Expand Up @@ -98,17 +104,25 @@ async function handler(
return;
}

// TODO(GROUPS_INFRA):
// - Implement view_filter return when a data source view is looked up.
// - If data_source_id is of the form `dsv_...` then it's a data source view
// and we pull the view_filter to return it below
// - otherwise it's a data source and the view_filter is null
// - Obviously this is where we check, based on the x-dust-group-ids header, that we
// have access to the data source or data source view

res.status(200).json({
project_id: parseInt(dataSource.dustAPIProjectId),
data_source_id: req.query.data_source_id,
view_filter: null,
});
return;

default:
res.status(405).end();
return;
}
return;

default:
res.status(405).end();
Expand Down
12 changes: 6 additions & 6 deletions front/pages/api/v1/w/[wId]/data_sources/[name]/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -246,16 +246,16 @@ async function handler(
target_document_tokens: query.target_document_tokens,
filter: {
tags: {
in: query.tags_in,
not: query.tags_not,
in: query.tags_in ?? null,
not: query.tags_not ?? null,
},
parents: {
in: query.parents_in,
not: query.parents_not,
in: query.parents_in ?? null,
not: query.parents_not ?? null,
},
timestamp: {
gt: query.timestamp_gt,
lt: query.timestamp_lt,
gt: query.timestamp_gt ?? null,
lt: query.timestamp_lt ?? null,
},
},
credentials,
Expand Down
12 changes: 6 additions & 6 deletions front/pages/api/w/[wId]/data_sources/[name]/search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,16 @@ export async function handleSearchDataSource({
target_document_tokens: searchQuery.target_document_tokens,
filter: {
tags: {
in: searchQuery.tags_in,
not: searchQuery.tags_not,
in: searchQuery.tags_in ?? null,
not: searchQuery.tags_not ?? null,
},
parents: {
in: searchQuery.parents_in,
not: searchQuery.parents_not,
in: searchQuery.parents_in ?? null,
not: searchQuery.parents_not ?? null,
},
timestamp: {
gt: searchQuery.timestamp_gt,
lt: searchQuery.timestamp_lt,
gt: searchQuery.timestamp_gt ?? null,
lt: searchQuery.timestamp_lt ?? null,
},
},
credentials: credentials,
Expand Down
36 changes: 22 additions & 14 deletions types/src/front/lib/core_api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,21 @@ export type CoreAPIQueryResult = {
value: Record<string, unknown>;
};

/**
 * Filter sent to core alongside a data source search, and returned by the
 * registry lookup as `view_filter`.
 *
 * Every section (`tags`, `parents`, `timestamp`) is required but may be
 * `null`; inside a section each bound is likewise required and nullable.
 * Callers with optional inputs should coalesce `undefined` to `null`.
 */
export type CoreAPISearchFilter = {
  // Tag constraints: values to include (`in`) and to exclude (`not`).
  tags: { in: Array<string> | null; not: Array<string> | null } | null;
  // Parent constraints, same `in` / `not` shape as `tags`.
  parents: { in: Array<string> | null; not: Array<string> | null } | null;
  // Timestamp bounds (`gt` / `lt`); units are not specified here — confirm against core.
  timestamp: { gt: number | null; lt: number | null } | null;
};

export class CoreAPI {
_url: string;
declare _logger: LoggerInterface;
Expand Down Expand Up @@ -277,6 +292,9 @@ export class CoreAPI {
credentials,
secrets,
}: CoreAPICreateRunParams): Promise<CoreAPIResponse<{ run: CoreAPIRun }>> {
// TODO(GROUPS_INFRA): use the auth as argument of this method instead of `runAsWorkspaceId`
// and pass both X-Dust-Workspace-Id and X-Dust-Group-Ids.

const response = await this._fetchWithError(
`${this._url}/projects/${encodeURIComponent(projectId)}/runs`,
{
Expand Down Expand Up @@ -318,6 +336,9 @@ export class CoreAPI {
dustRunId: Promise<string>;
}>
> {
// TODO(GROUPS_INFRA): use the auth as argument of this method instead of `runAsWorkspaceId`
// and pass both X-Dust-Workspace-Id and X-Dust-Group-Ids.

const res = await this._fetchWithError(
`${this._url}/projects/${projectId}/runs/stream`,
{
Expand Down Expand Up @@ -614,20 +635,7 @@ export class CoreAPI {
payload: {
query: string;
topK: number;
filter?: {
tags: {
in?: string[] | null;
not?: string[] | null;
};
parents?: {
in?: string[] | null;
not?: string[] | null;
};
timestamp?: {
gt?: number | null;
lt?: number | null;
};
} | null;
filter?: CoreAPISearchFilter | null;
fullText: boolean;
credentials: { [key: string]: string };
target_document_tokens?: number | null;
Expand Down

0 comments on commit dc9948c

Please sign in to comment.