Skip to content

Commit

Permalink
refactor: Enhance agent tools loading logic to respect capabilities a…
Browse files Browse the repository at this point in the history
…nd filter tools accordingly
  • Loading branch information
danny-avila committed Dec 19, 2024
1 parent 47f79a8 commit 16e3b7a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 46 deletions.
3 changes: 2 additions & 1 deletion api/app/clients/tools/util/handleTools.js
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ const loadTools = async ({

const toolContextMap = {};
const remainingTools = [];
const appTools = options.req?.app?.locals?.availableTools ?? {};

for (const tool of tools) {
if (tool === Tools.execute_code) {
Expand Down Expand Up @@ -259,7 +260,7 @@ const loadTools = async ({
return createFileSearchTool({ req: options.req, files, entity_id: agent?.id });
};
continue;
} else if (mcpToolPattern.test(tool)) {
} else if (tool && appTools[tool] && mcpToolPattern.test(tool)) {
requestedTools[tool] = async () =>
createMCPTool({
req: options.req,
Expand Down
131 changes: 86 additions & 45 deletions api/server/services/ToolService.js
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@ const {
ErrorTypes,
ContentTypes,
imageGenTools,
EModelEndpoint,
actionDelimiter,
ImageVisionTool,
openapiToFunction,
AgentCapabilities,
validateAndParseOpenAPISpec,
} = require('librechat-data-provider');
const { processFileURL, uploadImageBuffer } = require('~/server/services/Files/process');
const { loadActionSets, createActionTool, domainParser } = require('./ActionService');
const { getEndpointsConfig } = require('~/server/services/Config');
const { recordUsage } = require('~/server/services/Threads');
const { loadTools } = require('~/app/clients/tools/util');
const { redactMessage } = require('~/config/parsers');
Expand Down Expand Up @@ -383,11 +386,37 @@ async function loadAgentTools({ req, agent, tool_resources, openAIApiKey }) {
if (!agent.tools || agent.tools.length === 0) {
return {};
}

const endpointsConfig = await getEndpointsConfig(req);
const capabilities = endpointsConfig?.[EModelEndpoint.agents]?.capabilities ?? [];
const areToolsEnabled = capabilities.includes(AgentCapabilities.tools);
if (!areToolsEnabled) {
logger.debug('Tools are not enabled for this agent.');
return {};
}

const isFileSearchEnabled = capabilities.includes(AgentCapabilities.file_search);
const isCodeEnabled = capabilities.includes(AgentCapabilities.execute_code);
const areActionsEnabled = capabilities.includes(AgentCapabilities.actions);

const _agentTools = agent.tools?.filter((tool) => {
if (tool === Tools.file_search && !isFileSearchEnabled) {
return false;
} else if (tool === Tools.execute_code && !isCodeEnabled) {
return false;
}
return true;
});

if (!_agentTools || _agentTools.length === 0) {
return {};
}

const { loadedTools, toolContextMap } = await loadTools({
agent,
functions: true,
user: req.user.id,
tools: agent.tools,
tools: _agentTools,
options: {
req,
openAIApiKey,
Expand Down Expand Up @@ -434,62 +463,74 @@ async function loadAgentTools({ req, agent, tool_resources, openAIApiKey }) {
return map;
}, {});

if (!areActionsEnabled) {
return {
tools: agentTools,
toolContextMap,
};
}

let actionSets = [];
const ActionToolMap = {};

for (const toolName of agent.tools) {
if (!ToolMap[toolName]) {
if (!actionSets.length) {
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
}
for (const toolName of _agentTools) {
if (ToolMap[toolName]) {
continue;
}

let actionSet = null;
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
if (!actionSets.length) {
actionSets = (await loadActionSets({ agent_id: agent.id })) ?? [];
}

let actionSet = null;
let currentDomain = '';
for (let action of actionSets) {
const domain = await domainParser(req, action.metadata.domain, true);
if (toolName.includes(domain)) {
currentDomain = domain;
actionSet = action;
break;
}
}

if (actionSet) {
const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
if (!actionSet) {
continue;
}

const validationResult = validateAndParseOpenAPISpec(actionSet.metadata.raw_spec);
if (validationResult.spec) {
const { requestBuilders, functionSignatures, zodSchemas } = openapiToFunction(
validationResult.spec,
true,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];

if (requestBuilder) {
const tool = await createActionTool({
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
const functionName = toolName.replace(`${actionDelimiter}${currentDomain}`, '');
const functionSig = functionSignatures.find((sig) => sig.name === functionName);
const requestBuilder = requestBuilders[functionName];
const zodSchema = zodSchemas[functionName];

if (requestBuilder) {
const tool = await createActionTool({
action: actionSet,
requestBuilder,
zodSchema,
name: toolName,
description: functionSig.description,
});
if (!tool) {
logger.warn(
`Invalid action: user: ${req.user.id} | agent_id: ${agent.id} | toolName: ${toolName}`,
);
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
throw new Error(`{"type":"${ErrorTypes.INVALID_ACTION}"}`);
}
agentTools.push(tool);
ActionToolMap[toolName] = tool;
}
}
}

if (agent.tools.length > 0 && agentTools.length === 0) {
throw new Error('No tools found for the specified tool calls.');
if (_agentTools.length > 0 && agentTools.length === 0) {
logger.warn(`No tools found for the specified tool calls: ${_agentTools.join(', ')}`);
return {};
}

return {
Expand Down

0 comments on commit 16e3b7a

Please sign in to comment.