Skip to content

Commit

Permalink
More improvements in vaults api
Browse files Browse the repository at this point in the history
  • Loading branch information
tdraier committed Aug 12, 2024
1 parent 848afa9 commit 222d4ba
Show file tree
Hide file tree
Showing 18 changed files with 167 additions and 124 deletions.
52 changes: 52 additions & 0 deletions front/lib/api/agent_data_sources.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import type {
ConnectorProvider,
DataSourceType,
DataSourceViewType,
ModelId,
} from "@dust-tt/types";
import { Sequelize } from "sequelize";
Expand All @@ -10,6 +11,7 @@ import { AgentDataSourceConfiguration } from "@app/lib/models/assistant/actions/
import { AgentRetrievalConfiguration } from "@app/lib/models/assistant/actions/retrieval";
import { AgentConfiguration } from "@app/lib/models/assistant/agent";
import { DataSource } from "@app/lib/models/data_source";
import { DataSourceViewModel } from "@app/lib/resources/storage/models/data_source_view";

export type DataSourcesUsageByAgent = Record<ModelId, number>;

Expand Down Expand Up @@ -128,3 +130,53 @@ export async function getDataSourceUsage({
],
});
}

export async function getDataSourceViewUsage({
auth,
dataSourceView,
}: {
auth: Authenticator;
dataSourceView: DataSourceViewType;
}): Promise<number> {
const owner = auth.workspace();

// This condition is critical it checks that we can identify the workspace and that the current
// auth is a user for this workspace. Checking `auth.isUser()` is critical as it would otherwise
// be possible to access data sources without being authenticated.
if (!owner || !auth.isUser()) {
return 0;
}

return AgentDataSourceConfiguration.count({
where: {
dataSourceViewId: dataSourceView.id,
},
include: [
{
model: DataSourceViewModel,
as: "dataSourceView",
where: {
workspaceId: owner.id,
},
attributes: [],
},
{
model: AgentRetrievalConfiguration,
as: "agent_retrieval_configuration",
attributes: [],
required: true,
include: [
{
model: AgentConfiguration,
as: "agent_configuration",
attributes: [],
required: true,
where: {
status: "active",
},
},
],
},
],
});
}
40 changes: 16 additions & 24 deletions front/lib/api/vaults.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ import { DataSourceViewResource } from "@app/lib/resources/data_source_view_reso
import type { VaultResource } from "@app/lib/resources/vault_resource";
import logger from "@app/logger/logger";

export const getDataSourceInfo = (
export const getDataSourceInfo = async (
auth: Authenticator,
dataSource: DataSourceResource
): DataSourceOrViewInfo => {
): Promise<DataSourceOrViewInfo> => {
return {
...dataSource.toJSON(),
sId: dataSource.name,
usage: 0, // TODO: Implement usage calculation
usage: await dataSource.getUsagesByAgents(auth),
category: getDataSourceCategory(dataSource),
};
};
Expand All @@ -34,16 +34,18 @@ export const getDataSourceInfos = async (
): Promise<DataSourceOrViewInfo[]> => {
const dataSources = await DataSourceResource.listByVault(auth, vault);

return dataSources.map((dataSource) => getDataSourceInfo(dataSource));
return Promise.all(
dataSources.map((dataSource) => getDataSourceInfo(auth, dataSource))
);
};

export const getDataSourceViewInfo = (
export const getDataSourceViewInfo = async (
auth: Authenticator,
dataSourceView: DataSourceViewResource
): DataSourceOrViewInfo => {
): Promise<DataSourceOrViewInfo> => {
return {
...(dataSourceView.dataSource as DataSourceResource).toJSON(),
...dataSourceView.toJSON(),
usage: 0, // TODO: Implement usage calculation
usage: await dataSourceView.getUsagesByAgents(auth),
category: getDataSourceCategory(
dataSourceView.dataSource as DataSourceResource
),
Expand All @@ -56,29 +58,19 @@ export const getDataSourceViewsInfo = async (
): Promise<DataSourceOrViewInfo[]> => {
const dataSourceViews = await DataSourceViewResource.listByVault(auth, vault);

return dataSourceViews.map((view) => getDataSourceViewInfo(view));
return Promise.all(
dataSourceViews.map((view) => getDataSourceViewInfo(auth, view))
);
};

export const isFolderDataSource = (dataSource: DataSourceResource): boolean =>
!dataSource.connectorProvider;

export const isWebfolderDataSource = (
dataSource: DataSourceResource
): boolean => dataSource.connectorProvider === "webcrawler";

export const isConnectedDataSource = (
dataSource: DataSourceResource
): boolean =>
!isFolderDataSource(dataSource) && !isWebfolderDataSource(dataSource);

export const getDataSourceCategory = (
dataSource: DataSourceResource
): DataSourceOrViewCategory => {
if (isFolderDataSource(dataSource)) {
if (dataSource.isFolder()) {
return "files";
}

if (isWebfolderDataSource(dataSource)) {
if (dataSource.isWebcrawler()) {
return "webfolder";
}

Expand Down
14 changes: 14 additions & 0 deletions front/lib/resources/data_source_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import type {
} from "sequelize";
import { Op } from "sequelize";

import { getDataSourceUsage } from "@app/lib/api/agent_data_sources";
import type { Authenticator } from "@app/lib/auth";
import { AgentDataSourceConfiguration } from "@app/lib/models/assistant/actions/data_sources";
import { DataSource } from "@app/lib/models/data_source";
Expand Down Expand Up @@ -235,11 +236,24 @@ export class DataSourceResource extends ResourceWithVault<DataSource> {
);
}

isFolder() {
return !this.connectorProvider;
}

isWebcrawler() {
return this.connectorProvider === "webcrawler";
}

getUsagesByAgents(auth: Authenticator) {
return getDataSourceUsage({ auth, dataSource: this.toJSON() });
}

// Serialization.

toJSON(): DataSourceType {
return {
id: this.id,
sId: this.name, // TODO(thomas 20240812) Migrate to a real sId
createdAt: this.createdAt.getTime(),
name: this.name,
description: this.description,
Expand Down
25 changes: 14 additions & 11 deletions front/lib/resources/data_source_view_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import type {
Transaction,
} from "sequelize";

import { getDataSourceViewUsage } from "@app/lib/api/agent_data_sources";
import type { Authenticator } from "@app/lib/auth";
import { DataSourceResource } from "@app/lib/resources/data_source_resource";
import type { ResourceFindOptions } from "@app/lib/resources/resource_with_vault";
Expand Down Expand Up @@ -156,14 +157,6 @@ export class DataSourceViewResource extends ResourceWithVault<DataSourceViewMode
return dataSource ?? null;
}

// Peer fetching.

async fetchDataSource(
auth: Authenticator
): Promise<DataSourceResource | null> {
return DataSourceResource.fetchByModelIdWithAuth(auth, this.dataSourceId);
}

// Updating.
async updateParents(
auth: Authenticator,
Expand Down Expand Up @@ -235,8 +228,8 @@ export class DataSourceViewResource extends ResourceWithVault<DataSourceViewMode

// Getters.

get dataSource(): DataSourceResource | undefined {
return this.ds;
get dataSource(): DataSourceResource {
return this.ds as DataSourceResource;
}

// sId logic.
Expand Down Expand Up @@ -265,14 +258,24 @@ export class DataSourceViewResource extends ResourceWithVault<DataSourceViewMode
return isResourceSId("data_source_view", sId);
}

getUsagesByAgents = async (auth: Authenticator) => {
return getDataSourceViewUsage({ auth, dataSourceView: this.toJSON() });
};

// Serialization.

toJSON(): DataSourceViewType {
return {
id: this.id,
sId: this.sId,
createdAt: this.createdAt.getTime(),
parentsIn: this.parentsIn,
sId: this.sId,
updatedAt: this.updatedAt.getTime(),
connectorId: this.dataSource.connectorId,
connectorProvider: this.dataSource.connectorProvider,
name: this.dataSource.name,
description: this.dataSource.description,
dustAPIProjectId: this.dataSource.dustAPIProjectId,
};
}
}
27 changes: 10 additions & 17 deletions front/lib/resources/group_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,8 @@ export class GroupResource extends BaseResource<GroupModel> {

async addMembers(
auth: Authenticator,
users: UserType[],
transaction?: Transaction
users: UserType[]
): Promise<Result<undefined, Error>> {
// Checking that the user is a member of the workspace.
const owner = auth.getNonNullableWorkspace();

if (users.length === 0) {
Expand All @@ -404,7 +402,6 @@ export class GroupResource extends BaseResource<GroupModel> {
const workspaceMemberships = await MembershipResource.getActiveMemberships({
users: userResources,
workspace: owner,
transaction,
});

if (
Expand All @@ -425,13 +422,15 @@ export class GroupResource extends BaseResource<GroupModel> {
// Check if the user is already a member of the group.
const activeMembers = await this.getActiveMembers(auth);
const activeMembersIds = activeMembers.map((m) => m.sId);
const alreadyActive = userIds.filter((userId) =>
const alreadyActiveUserIds = userIds.filter((userId) =>
activeMembersIds.includes(userId)
);
if (alreadyActive.length > 0) {
return alreadyActive.length === 1
? new Err(new Error(`User ${alreadyActive} already member.`))
: new Err(new Error(`Users ${alreadyActive} already members.`));
if (alreadyActiveUserIds.length > 0) {
return alreadyActiveUserIds.length === 1
? new Err(new Error(`User ${alreadyActiveUserIds} is already member.`))
: new Err(
new Error(`Users ${alreadyActiveUserIds} are already members.`)
);
}

// Create a new membership.
Expand All @@ -441,8 +440,7 @@ export class GroupResource extends BaseResource<GroupModel> {
userId: user.id,
workspaceId: owner.id,
startAt: new Date(),
})),
{ transaction }
}))
);

return new Ok(undefined);
Expand All @@ -457,12 +455,9 @@ export class GroupResource extends BaseResource<GroupModel> {

async removeMembers(
auth: Authenticator,
users: UserType[],
transaction?: Transaction
users: UserType[]
): Promise<Result<undefined, Error>> {
// Checking that the user is a member of the workspace.
const owner = auth.getNonNullableWorkspace();

if (users.length === 0) {
return new Ok(undefined);
}
Expand All @@ -479,7 +474,6 @@ export class GroupResource extends BaseResource<GroupModel> {
const workspaceMemberships = await MembershipResource.getActiveMemberships({
users: userResources,
workspace: owner,
transaction,
});

if (workspaceMemberships.length !== userIds.length) {
Expand Down Expand Up @@ -518,7 +512,6 @@ export class GroupResource extends BaseResource<GroupModel> {
startAt: { [Op.lte]: new Date() },
[Op.or]: [{ endAt: null }, { endAt: { [Op.gt]: new Date() } }],
},
transaction,
}
);

Expand Down
16 changes: 12 additions & 4 deletions front/lib/resources/vault_resource.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ export class VaultResource extends BaseResource<VaultModel> {
},
});

return vaults.map((vault) => new this(VaultModel, vault.get()));
return vaults
.map((vault) => new this(VaultModel, vault.get()))
.filter(
(vault) => auth.isAdmin() || auth.hasPermission([vault.acl()], "read")
);
}

static async fetchWorkspaceSystemVault(
Expand Down Expand Up @@ -159,18 +163,22 @@ export class VaultResource extends BaseResource<VaultModel> {
return null;
}

const vault = await this.model.findOne({
const vaultModel = await this.model.findOne({
where: {
id: vaultModelId,
workspaceId: owner.id,
},
});

if (!vault) {
if (!vaultModel) {
return null;
}
const vault = new this(VaultModel, vaultModel.get());
if (!auth.isAdmin() && !auth.hasPermission([vault.acl()], "read")) {
return null;
}

return new this(VaultModel, vault.get());
return vault;
}

static async isNameAvailable(
Expand Down
5 changes: 1 addition & 4 deletions front/pages/api/registry/[type]/lookup.ts
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,7 @@ async function handleDataSourceView(

// TODO(2024-08-02 flav) Uncomment.
// if (hasAccessToDataSourceView) {
const dataSource = await dataSourceView.fetchDataSource(auth);
if (!dataSource) {
return new Err(new Error("Data source not found for view."));
}
const dataSource = dataSourceView.dataSource;
return new Ok({
project_id: parseInt(dataSource.dustAPIProjectId),
data_source_id: dataSource.name,
Expand Down
11 changes: 1 addition & 10 deletions front/pages/api/w/[wId]/data_sources/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,7 @@ async function handler(
}

res.status(201).json({
dataSource: {
id: ds.id,
createdAt: ds.createdAt.getTime(),
name: ds.name,
description: ds.description,
dustAPIProjectId: ds.dustAPIProjectId,
assistantDefaultSelected: ds.assistantDefaultSelected,
connectorId: null,
connectorProvider: null,
},
dataSource: ds.toJSON(),
});
return;

Expand Down
Loading

0 comments on commit 222d4ba

Please sign in to comment.