Skip to content

Commit

Permalink
fix:ensure correct sequence in tool calls (explodinggradients#1371)
Browse files Browse the repository at this point in the history
  • Loading branch information
shahules786 authored Sep 27, 2024
1 parent e63a776 commit 1a64d9f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
28 changes: 21 additions & 7 deletions docs/concepts/metrics/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,33 @@ from ragas.messages import HumanMessage,AIMessage,ToolMessage,ToolCall
from ragas.metrics._tool_call_accuracy import ToolCallAccuracy


sample = MultiTurnSample(user_input=[
HumanMessage(content="Hey, book a table at the nearest best Chinese restaurant for 8:00pm"),
AIMessage(content="Sure, let me find the best options for you.", tool_calls=[
ToolCall(name="restaurant_search", args={"cuisine": "Asian", "time": "8:00pm"})
sample = [
HumanMessage(content="What's the weather like in New York right now?"),
AIMessage(content="The current temperature in New York is 75°F and it's partly cloudy.", tool_calls=[
ToolCall(name="weather_check", args={"location": "New York"})
]),
],
reference_tool_calls=[ToolCall(name="restaurant_book", args={"name": "Golden", "time": "8:00pm"})
])
HumanMessage(content="Can you translate that to Celsius?"),
AIMessage(content="Let me convert that to Celsius for you.", tool_calls=[
ToolCall(name="temperature_conversion", args={"temperature_fahrenheit": 75})
]),
ToolMessage(content="75°F is approximately 23.9°C."),
AIMessage(content="75°F is approximately 23.9°C.")
]

sampl2 = MultiTurnSample(
user_input=sample,
reference_tool_calls=[
ToolCall(name="weather_check", args={"location": "New York"}),
ToolCall(name="temperature_conversion", args={"temperature_fahrenheit": 75})
]
)

scorer = ToolCallAccuracy()
await scorer.multi_turn_ascore(sampl2)
```

The tool call sequence specified in `reference_tool_calls` is used as the ideal outcome. If the tool calls made by the AI do not match the order or sequence of the `reference_tool_calls`, the metric will return a score of 0. This helps to ensure that the AI is able to identify and call the required tools in the correct order to complete a given task.

By default the tool names and arguments are compared using exact string matching. But sometimes this might not be optimal, for example if the args are natural language strings. You can also use any ragas metric (with values between 0 and 1) as a distance measure to identify whether a retrieved context is relevant or not. For example,

```python
Expand Down
39 changes: 31 additions & 8 deletions src/ragas/metrics/_tool_call_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import warnings
from dataclasses import dataclass, field

import numpy as np

from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
from ragas.messages import AIMessage
Expand Down Expand Up @@ -49,29 +48,53 @@ async def _get_arg_score(

return score / len(refs.keys())

def is_sequence_aligned(
    self, pred_sequence: t.List[str], ref_sequence: t.List[str]
) -> bool:
    """Check whether ``ref_sequence`` occurs, in order, as a subsequence of
    ``pred_sequence``.

    Predicted names may have extra entries in between; only the relative
    order of the reference names matters. Returns False when no prediction
    is available to complete the match (including when both lists are empty).
    """
    n_ref = len(ref_sequence)
    matched = 0  # how many reference names have been matched so far, in order
    for name in pred_sequence:
        if matched < n_ref and name == ref_sequence[matched]:
            matched += 1
        if matched == n_ref:
            # Every reference name was found in order within the predictions.
            return True
    return False

async def _multi_turn_ascore(
self, sample: MultiTurnSample, callbacks: Callbacks
) -> float:
assert sample.reference_tool_calls is not None, "Reference is not set"

if isinstance(sample.user_input[-1], AIMessage):
if sample.user_input[-1].tool_calls is None:
return np.nan
pred_tool_calls = []
for item in sample.user_input:
if isinstance(item, AIMessage) and item.tool_calls is not None:
pred_tool_calls.extend(item.tool_calls)

tool_call_pred_sequence = [tool_call.name for tool_call in pred_tool_calls]
tool_call_ref_sequence = [
tool_call.name for tool_call in sample.reference_tool_calls
]

sequence_aligned = int(
self.is_sequence_aligned(tool_call_pred_sequence, tool_call_ref_sequence)
)

if pred_tool_calls:
score = 0.0
reference_tool_calls = sample.reference_tool_calls
for ref_tool_call in reference_tool_calls:
for pred_tool_call in sample.user_input[-1].tool_calls:
for pred_tool_call in pred_tool_calls:
if ref_tool_call.name == pred_tool_call.name:
arg_score = await self._get_arg_score(
pred_tool_call.args, ref_tool_call.args, callbacks
)
score += arg_score

return score / len(reference_tool_calls)
score /= len(reference_tool_calls)
else:
warnings.warn("Last message is not an AIMessage with ToolCalls")
return np.nan
warnings.warn("No tool calls found in the user input")
return 0.0

return score * sequence_aligned

async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
return await self._multi_turn_ascore(MultiTurnSample(**row), callbacks)

0 comments on commit 1a64d9f

Please sign in to comment.