Skip to content

Commit

Permalink
✨ refactor: Integrate Capabilities into Agent File Uploads and Tool H…
Browse files Browse the repository at this point in the history
…andling (danny-avila#5048)

* refactor: support drag/drop files for agents, handle undefined tool_resource edge cases

* refactor: consolidate endpoints config logic to dedicated getter

* refactor: Enhance agent tools loading logic to respect capabilities and filter tools accordingly

* refactor: Integrate endpoint capabilities into file upload dropdown for dynamic resource handling

* refactor: Implement capability checks for agent file upload operations

* fix: non-image tool_resource check
  • Loading branch information
danny-avila authored and olivierhub committed Dec 20, 2024
1 parent 6109f0f commit 291dc67
Show file tree
Hide file tree
Showing 17 changed files with 448 additions and 188 deletions.
3 changes: 2 additions & 1 deletion api/app/clients/tools/util/handleTools.js
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ const loadTools = async ({

const toolContextMap = {};
const remainingTools = [];
const appTools = options.req?.app?.locals?.availableTools ?? {};

for (const tool of tools) {
if (tool === Tools.execute_code) {
Expand Down Expand Up @@ -259,7 +260,7 @@ const loadTools = async ({
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
};
continue;
} else if (mcpToolPattern.test(tool)) {
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () =>
createMCPTool({
req: options.req,
Expand Down
66 changes: 2 additions & 64 deletions api/server/controllers/EndpointController.js
Original file line number Diff line number Diff line change
@@ -1,69 +1,7 @@
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
const { loadDefaultEndpointsConfig, loadConfigEndpoints } = require('~/server/services/Config');
const { getLogStores } = require('~/cache');
const { getEndpointsConfig } = require('~/server/services/Config');

async function endpointController(req, res) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
if (cachedEndpointsConfig) {
res.send(cachedEndpointsConfig);
return;
}

const defaultEndpointsConfig = await loadDefaultEndpointsConfig(req);
const customConfigEndpoints = await loadConfigEndpoints(req);

/** @type {TEndpointsConfig} */
const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.assistants];

mergedConfig[EModelEndpoint.assistants] = {
...mergedConfig[EModelEndpoint.assistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];

mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
disableBuilder,
capabilities,
};
}

if (
mergedConfig[EModelEndpoint.azureAssistants] &&
req.app.locals?.[EModelEndpoint.azureAssistants]
) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.azureAssistants];

mergedConfig[EModelEndpoint.azureAssistants] = {
...mergedConfig[EModelEndpoint.azureAssistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}

if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) {
const { availableRegions } = req.app.locals[EModelEndpoint.bedrock];
mergedConfig[EModelEndpoint.bedrock] = {
...mergedConfig[EModelEndpoint.bedrock],
availableRegions,
};
}

const endpointsConfig = orderEndpointsConfig(mergedConfig);

await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
const endpointsConfig = await getEndpointsConfig(req);
res.send(JSON.stringify(endpointsConfig));
}

Expand Down
10 changes: 3 additions & 7 deletions api/server/controllers/assistants/helpers.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
const {
CacheKeys,
SystemRoles,
EModelEndpoint,
defaultOrderQuery,
Expand All @@ -9,7 +8,7 @@ const {
initializeClient: initAzureClient,
} = require('~/server/services/Endpoints/azureAssistants');
const { initializeClient } = require('~/server/services/Endpoints/assistants');
const { getLogStores } = require('~/cache');
const { getEndpointsConfig } = require('~/server/services/Config');

/**
* @param {Express.Request} req
Expand All @@ -23,11 +22,8 @@ const getCurrentVersion = async (req, endpoint) => {
version = `v${req.body.version}`;
}
if (!version && endpoint) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
version = `v${
cachedEndpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]
}`;
const endpointsConfig = await getEndpointsConfig(req);
version = `v${endpointsConfig?.[endpoint]?.version ?? defaultAssistantsVersion[endpoint]}`;
}
if (!version?.startsWith('v') && version.length !== 2) {
throw new Error(`[${req.baseUrl}] Invalid version: ${version}`);
Expand Down
75 changes: 75 additions & 0 deletions api/server/services/Config/getEndpointsConfig.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
const { CacheKeys, EModelEndpoint, orderEndpointsConfig } = require('librechat-data-provider');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');
const loadConfigEndpoints = require('./loadConfigEndpoints');
const getLogStores = require('~/cache/getLogStores');

/**
*
* @param {ServerRequest} req
* @returns {Promise<TEndpointsConfig>}
*/
async function getEndpointsConfig(req) {
const cache = getLogStores(CacheKeys.CONFIG_STORE);
const cachedEndpointsConfig = await cache.get(CacheKeys.ENDPOINT_CONFIG);
if (cachedEndpointsConfig) {
return cachedEndpointsConfig;
}

const defaultEndpointsConfig = await loadDefaultEndpointsConfig(req);
const customConfigEndpoints = await loadConfigEndpoints(req);

/** @type {TEndpointsConfig} */
const mergedConfig = { ...defaultEndpointsConfig, ...customConfigEndpoints };
if (mergedConfig[EModelEndpoint.assistants] && req.app.locals?.[EModelEndpoint.assistants]) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.assistants];

mergedConfig[EModelEndpoint.assistants] = {
...mergedConfig[EModelEndpoint.assistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}
if (mergedConfig[EModelEndpoint.agents] && req.app.locals?.[EModelEndpoint.agents]) {
const { disableBuilder, capabilities, ..._rest } = req.app.locals[EModelEndpoint.agents];

mergedConfig[EModelEndpoint.agents] = {
...mergedConfig[EModelEndpoint.agents],
disableBuilder,
capabilities,
};
}

if (
mergedConfig[EModelEndpoint.azureAssistants] &&
req.app.locals?.[EModelEndpoint.azureAssistants]
) {
const { disableBuilder, retrievalModels, capabilities, version, ..._rest } =
req.app.locals[EModelEndpoint.azureAssistants];

mergedConfig[EModelEndpoint.azureAssistants] = {
...mergedConfig[EModelEndpoint.azureAssistants],
version,
retrievalModels,
disableBuilder,
capabilities,
};
}

if (mergedConfig[EModelEndpoint.bedrock] && req.app.locals?.[EModelEndpoint.bedrock]) {
const { availableRegions } = req.app.locals[EModelEndpoint.bedrock];
mergedConfig[EModelEndpoint.bedrock] = {
...mergedConfig[EModelEndpoint.bedrock],
availableRegions,
};
}

const endpointsConfig = orderEndpointsConfig(mergedConfig);

await cache.set(CacheKeys.ENDPOINT_CONFIG, endpointsConfig);
return endpointsConfig;
}

module.exports = { getEndpointsConfig };
6 changes: 2 additions & 4 deletions api/server/services/Config/index.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@ const getCustomConfig = require('./getCustomConfig');
const loadCustomConfig = require('./loadCustomConfig');
const loadConfigModels = require('./loadConfigModels');
const loadDefaultModels = require('./loadDefaultModels');
const getEndpointsConfig = require('./getEndpointsConfig');
const loadOverrideConfig = require('./loadOverrideConfig');
const loadAsyncEndpoints = require('./loadAsyncEndpoints');
const loadConfigEndpoints = require('./loadConfigEndpoints');
const loadDefaultEndpointsConfig = require('./loadDefaultEConfig');

module.exports = {
config,
Expand All @@ -16,6 +15,5 @@ module.exports = {
loadOverrideConfig,
loadAsyncEndpoints,
...getCustomConfig,
loadConfigEndpoints,
loadDefaultEndpointsConfig,
...getEndpointsConfig,
};
31 changes: 30 additions & 1 deletion api/server/services/Files/process.js
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ const {
EToolResources,
mergeFileConfig,
hostImageIdSuffix,
AgentCapabilities,
checkOpenAIStorage,
removeNullishValues,
hostImageNamePrefix,
Expand All @@ -27,6 +28,7 @@ const { addResourceFileId, deleteResourceFileId } = require('~/server/controller
const { addAgentResourceFile, removeAgentResourceFiles } = require('~/models/Agent');
const { getOpenAIClient } = require('~/server/controllers/assistants/helpers');
const { createFile, updateFileUsage, deleteFiles } = require('~/models/File');
const { getEndpointsConfig } = require('~/server/services/Config');
const { loadAuthValues } = require('~/app/clients/tools/util');
const { LB_QueueAsyncCall } = require('~/server/utils/queue');
const { getStrategyFunctions } = require('./strategies');
Expand Down Expand Up @@ -451,6 +453,17 @@ const processFileUpload = async ({ req, res, metadata }) => {
res.status(200).json({ message: 'File uploaded and processed successfully', ...result });
};

/**
* @param {ServerRequest} req
* @param {AgentCapabilities} capability
* @returns {Promise<boolean>}
*/
const checkCapability = async (req, capability) => {
const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
return capabilities.includes(capability);
};

/**
* Applies the current strategy for file uploads.
* Saves file metadata to the database with an expiry TTL.
Expand Down Expand Up @@ -478,9 +491,20 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
throw new Error('No agent ID provided for agent file upload');
}

const isImage = file.mimetype.startsWith('image');
if (!isImage && !tool_resource) {
/** Note: this needs to be removed when we can support files to providers */
throw new Error('No tool resource provided for non-image agent file upload');
}

let fileInfoMetadata;
const entity_id = messageAttachment === true ? undefined : agent_id;

if (tool_resource === EToolResources.execute_code) {
const isCodeEnabled = await checkCapability(req, AgentCapabilities.execute_code);
if (!isCodeEnabled) {
throw new Error('Code execution is not enabled for Agents');
}
const { handleFileUpload: uploadCodeEnvFile } = getStrategyFunctions(FileSources.execute_code);
const result = await loadAuthValues({ userId: req.user.id, authFields: [EnvVar.CODE_API_KEY] });
const stream = fs.createReadStream(file.path);
Expand All @@ -492,6 +516,11 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
entity_id,
});
fileInfoMetadata = { fileIdentifier };
} else if (tool_resource === EToolResources.file_search) {
const isFileSearchEnabled = await checkCapability(req, AgentCapabilities.file_search);
if (!isFileSearchEnabled) {
throw new Error('File search is not enabled for Agents');
}
}

const source =
Expand Down Expand Up @@ -527,7 +556,7 @@ const processAgentFileUpload = async ({ req, res, metadata }) => {
});
}

if (file.mimetype.startsWith('image')) {
if (isImage) {
const result = await processImageFile({
req,
file,
Expand Down
Loading

0 comments on commit 291dc67

Please sign in to comment.