Merge pull request #12 from awslabs/add-hooks
Add hooks
tonykchen authored Apr 1, 2024
2 parents 75ba77d + 1aa486d commit 321c97e
Showing 18 changed files with 262 additions and 59 deletions.
2 changes: 2 additions & 0 deletions docs/configuration.md
@@ -9,6 +9,7 @@ tests:
expected_results:
initial_prompt:
max_turns:
hook:
```
`evaluator` _(map)_
@@ -32,5 +33,6 @@ A list of tests. Each test maps with the following fields:
- `expected_results` _(list of strings)_: The expected results for the test.
- `initial_prompt` _(string; optional)_: The first message that is sent to the agent, which starts the conversation. If unspecified, the message will be generated based on the `steps` provided.
- `max_turns` _(integer; optional)_: The maximum number of user-agent exchanges before the test fails. The default is `2`.
- `hook` _(string; optional)_: The module path to an evaluation hook.
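
As a rough illustration of how the optional fields above behave, the sketch below builds test objects from an already-parsed plan. `PlanTest`, `load_tests`, and `DEFAULT_MAX_TURNS` are hypothetical stand-ins rather than the library's API, and the default of `2` is taken from the field description above.

```python
from dataclasses import dataclass
from typing import List, Optional

DEFAULT_MAX_TURNS = 2  # assumption: mirrors the documented default


@dataclass
class PlanTest:
    """Hypothetical stand-in for a test entry in the plan."""

    name: str
    steps: List[str]
    expected_results: List[str]
    initial_prompt: Optional[str] = None
    max_turns: int = DEFAULT_MAX_TURNS
    hook: Optional[str] = None


def load_tests(test_config: List[dict]) -> List[PlanTest]:
    # optional keys fall back to None or the default when they are absent
    return [
        PlanTest(
            name=t["name"],
            steps=t["steps"],
            expected_results=t["expected_results"],
            initial_prompt=t.get("initial_prompt"),
            max_turns=t.get("max_turns", DEFAULT_MAX_TURNS),
            hook=t.get("hook"),
        )
        for t in test_config
    ]


plan = [
    {
        "name": "MakeReservation",
        "steps": ["Ask agent to make a reservation."],
        "expected_results": ["The agent confirms the reservation."],
        "hook": "my_evaluation_hook.MyEvaluationHook",
    },
    {
        "name": "GetOpenClaims",
        "steps": ["Ask agent for open claims."],
        "expected_results": ["The agent lists the open claims."],
    },
]
tests = load_tests(plan)
```

A test without a `hook` key simply carries `None`, so no hook runs for it.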

---
1 change: 1 addition & 0 deletions docs/reference/hook.md
@@ -0,0 +1 @@
::: src.agenteval.hook
107 changes: 106 additions & 1 deletion docs/user_guide.md
@@ -6,7 +6,7 @@ To begin, initialize a test plan.
agenteval init
```

This will create an `agenteval.yaml` file in the current directory.
This will create an `agenteval.yml` file in the current directory.

```yaml
evaluator:
@@ -99,3 +99,108 @@ tests:
- The agent returns the details on claim-006.
initial_prompt: Can you let me know which claims are open?
```

## Evaluation hooks
You can specify hooks that run before and/or after a test is evaluated. Hooks are useful for integration testing and for any setup or cleanup tasks a test requires.

To create a hook, define a subclass of [Hook](reference/hook.md#src.agenteval.hook.Hook).

For a hook that runs *before* evaluation, implement a `pre_evaluate` method. In this method, you have access to the [Test](reference/test.md#src.agenteval.test.Test) and [TraceHandler](reference/trace_handler.md#src.agenteval.trace_handler.TraceHandler) via the `test` and `trace` arguments, respectively.

For a hook that runs *after* evaluation, implement a `post_evaluate` method. Similar to the `pre_evaluate` method, you have access to the `Test` and `TraceHandler`. You also have access to the [TestResult](reference/test_result.md#src.agenteval.test_result.TestResult) via the `test_result` argument. You may override the attributes of the `TestResult` if you plan to use this hook to perform additional testing, such as integration testing.

!!! example "my_evaluation_hook.py"

```python
from agenteval import Hook

class MyEvaluationHook(Hook):
    # hooks are invoked on the class itself, so the methods take no `self`
    def pre_evaluate(test, trace):
        # implement logic here
        pass

    def post_evaluate(test, test_result, trace):
        # implement logic here
        pass
```

Once you have created your subclass, specify the hook for the test in `agenteval.yml`.

!!! example "agenteval.yml"

```yaml
tests:
- name: MakeReservation
  hook: my_evaluation_hook.MyEvaluationHook
```
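
To make the ordering concrete, here is a minimal sketch of how the hook methods wrap an evaluation. It is a simplified model of the evaluator's `run` method in this commit, not the real implementation: the trace context manager is left out, and `run`, `fake_evaluate`, `RecordingHook`, and the local `Hook` class are hypothetical stand-ins.

```python
calls = []


class Hook:
    """Stand-in for the Hook base class; methods are called on the class."""

    def pre_evaluate(test, trace):
        pass

    def post_evaluate(test, test_result, trace):
        pass


class RecordingHook(Hook):
    def pre_evaluate(test, trace):
        calls.append("pre_evaluate")

    def post_evaluate(test, test_result, trace):
        calls.append("post_evaluate")


def fake_evaluate():
    # stand-in for the evaluator's conversation with the agent
    calls.append("evaluate")
    return {"success": True}


def run(test, trace, evaluate, hook_cls=None):
    # the hook is optional; without one, only the evaluation runs
    if hook_cls:
        hook_cls.pre_evaluate(test, trace)
    test_result = evaluate()
    if hook_cls:
        hook_cls.post_evaluate(test, test_result, trace)
    return test_result


result = run("MakeReservation", None, fake_evaluate, hook_cls=RecordingHook)
```

Because `post_evaluate` receives the same result object that `run` returns, any attribute it changes is visible to the caller.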

### Examples

#### Integration testing using `post_evaluate`

In this example, we will test an agent that can make dinner reservations. In addition to evaluating the conversation, we want to verify that the reservation is written to the backend database. To do this, we will create a post-evaluation hook that queries a PostgreSQL database for the reservation record. If the record is not found, the hook overrides the `TestResult` to mark the test as failed.

!!! example "test_record_insert.py"

```python
import json

import boto3
import psycopg2

from agenteval import Hook

SECRET_NAME = "test-secret"


def get_db_secret() -> dict:
    # fetch the database credentials from AWS Secrets Manager
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
    )
    get_secret_value_response = client.get_secret_value(
        SecretId=SECRET_NAME
    )
    secret = get_secret_value_response['SecretString']
    return json.loads(secret)


class TestRecordInsert(Hook):
    def post_evaluate(test, test_result, trace):
        # get database secret from AWS Secrets Manager
        secret = get_db_secret()
        # connect to database
        conn = psycopg2.connect(
            database=secret["dbname"],
            user=secret["username"],
            password=secret["password"],
            host=secret["host"],
            port=secret["port"]
        )
        # check if record is written to database, then release the connection
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT * FROM reservations WHERE name = 'Bob'")
                row = cur.fetchone()
        finally:
            conn.close()
        # override the test result based on query result
        if not row:
            test_result.success = False
            test_result.result = "Integration test failed"
            test_result.reasoning = "Record was not inserted into the database"
```

Create a test that references the hook.

!!! example "agenteval.yml"

```yaml
tests:
- name: MakeReservation
  steps:
  - Ask agent to make a reservation under the name Bob for 7 PM.
  expected_results:
  - The agent confirms that a reservation has been made.
  hook: test_record_insert.TestRecordInsert
```
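
The override logic can be tried out without AWS or PostgreSQL. The sketch below is a hedged stand-in: it uses an in-memory SQLite database in place of PostgreSQL and a `SimpleNamespace` in place of `TestResult`, and `check_reservation` is a hypothetical helper. It also passes the name as a query parameter rather than embedding it in the SQL string; note that `sqlite3` uses `?` placeholders, while `psycopg2` uses `%s`.

```python
import sqlite3
from types import SimpleNamespace


def check_reservation(conn, name, test_result):
    # look up the record with a parameterized query
    row = conn.execute(
        "SELECT * FROM reservations WHERE name = ?", (name,)
    ).fetchone()
    # override the test result only when the record is missing
    if not row:
        test_result.success = False
        test_result.result = "Integration test failed"
        test_result.reasoning = "Record was not inserted into the database"
    return test_result


def passing_result():
    # stand-in for a TestResult produced by a successful evaluation
    return SimpleNamespace(success=True, result="All expected results met", reasoning="")


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE reservations (name TEXT, time TEXT)")
conn.execute("INSERT INTO reservations VALUES (?, ?)", ("Bob", "19:00"))

found = check_reservation(conn, "Bob", passing_result())
missing = check_reservation(conn, "Alice", passing_result())
conn.close()
```

Here `found` keeps its original values, while `missing` is flipped to a failure even though the conversation itself passed.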
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -40,6 +40,7 @@ nav:
- cli: reference/cli.md
- conversation_handler: reference/conversation_handler.md
- evaluator: reference/evaluator.md
- hook: reference/hook.md
- target: reference/target.md
- test: reference/test.md
- test_result: reference/test_result.md
3 changes: 3 additions & 0 deletions src/agenteval/__init__.py
@@ -6,6 +6,9 @@
from jinja2 import Environment, PackageLoader
from rich.logging import RichHandler

from .hook import Hook

__all__ = ["Hook"]
_LOG_LEVEL_ENV = "LOG_LEVEL"
_FORMAT = "%(message)s"

2 changes: 1 addition & 1 deletion src/agenteval/evaluators/claude/claude.py
@@ -166,7 +166,7 @@ def _invoke_target(self, user_input) -> str:

return target_response.response

def run(self) -> TestResult:
def evaluate(self) -> TestResult:
success = False
eval_reasoning = ""
result = "Max turns reached."
33 changes: 26 additions & 7 deletions src/agenteval/evaluators/evaluator.py
@@ -6,10 +6,12 @@
from botocore.client import BaseClient

from agenteval.conversation_handler import ConversationHandler
from agenteval.hook import Hook
from agenteval.targets import BaseTarget
from agenteval.test import Test
from agenteval.test_result import TestResult
from agenteval.trace_handler import TraceHandler
from agenteval.utils import import_class, validate_subclass

_BEDROCK_RUNTIME_BOTO3_NAME = "bedrock-runtime"

@@ -19,41 +21,58 @@ class BaseEvaluator(ABC):
classes.
Attributes:
test (Test): The test to conduct.
test (Test): The test case.
target (BaseTarget): The target agent being evaluated.
conversation (ConversationHandler): Conversation handler for capturing the interaction
between the evaluator (user) and target (agent).
trace (TraceHandler): Trace handler for capturing steps during evaluation.
test_result (TestResult): The result of the test which is set in `BaseEvaluator.run`.
"""

def __init__(self, test: Test, target: BaseTarget):
"""Initialize the evaluator instance for a given `Test` and `Target`.
Args:
test (Test): The test to conduct.
test (Test): The test case.
target (BaseTarget): The target agent being evaluated.
"""
self.test = test
self.target = target
self.conversation = ConversationHandler()
self.trace = TraceHandler(test_name=test.name)
self.test_result = None

@abstractmethod
def run(self) -> TestResult:
def evaluate(self) -> TestResult:
"""Conduct a test.
Returns:
TestResult: The result of the test.
"""
pass

def trace_run(self):
def _get_hook_cls(self, hook: Optional[str]) -> Optional[type[Hook]]:
if hook:
hook_cls = import_class(hook)
validate_subclass(hook_cls, Hook)
return hook_cls

def run(self) -> TestResult:
"""
Runs the evaluator within a trace context manager to
capture trace data.
Run the evaluator within a trace context manager and run hooks
if provided.
"""

hook_cls = self._get_hook_cls(self.test.hook)

with self.trace:
return self.run()
if hook_cls:
hook_cls.pre_evaluate(self.test, self.trace)
self.test_result = self.evaluate()
if hook_cls:
hook_cls.post_evaluate(self.test, self.test_result, self.trace)

return self.test_result


class AWSEvaluator(BaseEvaluator):
30 changes: 30 additions & 0 deletions src/agenteval/hook.py
@@ -0,0 +1,30 @@
from agenteval.test import Test
from agenteval.test_result import TestResult
from agenteval.trace_handler import TraceHandler


class Hook:
"""An evaluation hook."""

def pre_evaluate(test: Test, trace: TraceHandler) -> None:
"""
Method called before evaluation. Can be used to perform any setup tasks.
Args:
test (Test): The test case.
trace (TraceHandler): Trace handler for capturing steps during evaluation.
"""
pass

def post_evaluate(test: Test, test_result: TestResult, trace: TraceHandler) -> None:
"""
Method called after evaluation. This may be used to perform integration testing
or clean up tasks.
Args:
test (Test): The test case.
test_result (TestResult): The result of the test, which can be overridden
by updating the attributes of this object.
trace (TraceHandler): Trace handler for capturing steps during evaluation.
"""
pass
27 changes: 7 additions & 20 deletions src/agenteval/plan.py
@@ -1,10 +1,9 @@
from __future__ import annotations

import importlib
import logging
import os
import sys
from typing import Optional, Union
from typing import Optional

import yaml
from pydantic import BaseModel, model_validator
@@ -18,6 +17,7 @@
SageMakerEndpointTarget,
)
from agenteval.test import Test
from agenteval.utils import import_class, validate_subclass

_PLAN_FILE_NAME = "agenteval.yml"

@@ -86,10 +86,8 @@ def create_evaluator(self, test: Test, target: BaseTarget) -> BaseEvaluator:
if self.evaluator_config.type in _EVALUATOR_MAP:
evaluator_cls = _EVALUATOR_MAP[self.evaluator_config.type]
else:
evaluator_cls = self._get_class(self.evaluator_config)

if not issubclass(evaluator_cls, BaseEvaluator):
raise TypeError(f"{evaluator_cls} is not a Evaluator subclass")
evaluator_cls = import_class(self.evaluator_config.type)
validate_subclass(evaluator_cls, BaseEvaluator)

return evaluator_cls(
test=test,
@@ -103,23 +101,11 @@ def create_target(
if self.target_config.type in _TARGET_MAP:
target_cls = _TARGET_MAP[self.target_config.type]
else:
target_cls = self._get_class(self.target_config)

if not issubclass(target_cls, BaseTarget):
raise TypeError(f"{target_cls} is not a Target subclass")
target_cls = import_class(self.target_config.type)
validate_subclass(target_cls, BaseTarget)

return target_cls(**self.target_config.model_dump(exclude="type"))

def _get_class(self, config: Union[EvaluatorConfig, TargetConfig]) -> type:
module_path, class_name = config.type.rsplit(".", 1)
return self._import_class(module_path, class_name)

@staticmethod
def _import_class(module_path: str, class_name: str) -> type[BaseTarget]:
module = importlib.import_module(module_path)
target_cls = getattr(module, class_name)
return target_cls

@staticmethod
def _load_yaml(path: str) -> dict:
with open(path) as stream:
@@ -136,6 +122,7 @@ def _load_tests(test_config: list[dict]) -> list[Test]:
expected_results=t.get("expected_results"),
initial_prompt=t.get("initial_prompt"),
max_turns=t.get("max_turns", defaults.MAX_TURNS),
hook=t.get("hook"),
)
)
return tests
2 changes: 1 addition & 1 deletion src/agenteval/runner/runner.py
@@ -63,7 +63,7 @@ def run_test(self, test):
target=target,
)

result = evaluator.trace_run()
result = evaluator.run()
if result.success is False:
self.num_failed += 1

12 changes: 7 additions & 5 deletions src/agenteval/test.py
@@ -7,11 +7,12 @@ class Test(BaseModel, validate_assignment=True):
"""A test case for an agent.
Attributes:
name: Name of the test
steps: List of step to perform for the test
expected_results: List of expected results for the test
initial_prompt: Optional initial prompt
max_turns: Maximum number of turns allowed for the test
name: Name of the test.
steps: List of steps to perform for the test.
expected_results: List of expected results for the test.
initial_prompt: Optional initial prompt.
max_turns: Maximum number of turns allowed for the test.
hook: The module path to an evaluation hook.
"""

# do not collect as a test
@@ -22,3 +23,4 @@ class Test(BaseModel, validate_assignment=True):
expected_results: list[str]
initial_prompt: Optional[str] = None
max_turns: int
hook: Optional[str] = None
3 changes: 3 additions & 0 deletions src/agenteval/utils/__init__.py
@@ -0,0 +1,3 @@
from .imports import import_class, validate_subclass

__all__ = ["import_class", "validate_subclass"]
14 changes: 14 additions & 0 deletions src/agenteval/utils/imports.py
@@ -0,0 +1,14 @@
from importlib import import_module


def import_class(module_path: str) -> type:
name, class_name = module_path.rsplit(".", 1)

return getattr(import_module(name), class_name)


def validate_subclass(child_class: type, parent_class: type) -> None:
if not issubclass(child_class, parent_class):
raise TypeError(
f"{child_class.__name__} is not a {parent_class.__name__} subclass"
)
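
For reviewers, a short usage sketch of the two helpers above may be useful. The function bodies are copied from this file; the inputs (`collections.OrderedDict`, `dict`, `list`) are standard-library classes chosen only for illustration and are not how the package itself calls the helpers.

```python
from importlib import import_module


def import_class(module_path: str) -> type:
    name, class_name = module_path.rsplit(".", 1)

    return getattr(import_module(name), class_name)


def validate_subclass(child_class: type, parent_class: type) -> None:
    if not issubclass(child_class, parent_class):
        raise TypeError(
            f"{child_class.__name__} is not a {parent_class.__name__} subclass"
        )


# resolve a dotted path into a class object
ordered_dict_cls = import_class("collections.OrderedDict")

# OrderedDict subclasses dict, so this returns without raising
validate_subclass(ordered_dict_cls, dict)

# a class outside the expected hierarchy raises TypeError
try:
    validate_subclass(ordered_dict_cls, list)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# a path with no dot cannot be split into module and class name
try:
    import_class("OrderedDict")
    undotted_failed = False
except ValueError:
    undotted_failed = True
```

One consequence worth noting: a `hook`, evaluator, or target value with no module prefix fails with a `ValueError` from the tuple unpacking rather than a more descriptive message.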