diff --git a/docs/configuration.md b/docs/configuration.md index 295ac7c..7bfe892 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -9,6 +9,7 @@ tests: expected_results: initial_prompt: max_turns: + hook: ``` `evaluator` _(map)_ @@ -32,5 +33,6 @@ A list of tests. Each test maps with the following fields: - `expected_results` _(list of strings)_: The expected results for the test. - `initial_prompt` _(string; optional)_: The first message that is sent to the agent, which starts the conversation. If unspecified, the message will be generated based on the `steps` provided. - `max_turns` _(integer; optional)_: The maximum number of user-agent exchanges before the test fails. The default is `2`. +- `hook` _(string; optional)_: The module path to an evaluation hook. --- diff --git a/docs/reference/hook.md b/docs/reference/hook.md new file mode 100644 index 0000000..6eefbbb --- /dev/null +++ b/docs/reference/hook.md @@ -0,0 +1 @@ +::: src.agenteval.hook \ No newline at end of file diff --git a/docs/targets/index.md b/docs/targets/index.md index 3a3b62e..18b6ab6 100644 --- a/docs/targets/index.md +++ b/docs/targets/index.md @@ -63,9 +63,9 @@ If you want to test an agent that is not natively supported, you can bring your return TargetResponse(response=response) ``` -Once you have created your subclass, specify the Target in `agenteval.yaml`. +Once you have created your subclass, specify the Target in `agenteval.yml`. -!!! example "agenteval.yaml" +!!! example "agenteval.yml" ```yaml target: @@ -124,7 +124,7 @@ We will create a simple [LangChain](https://python.langchain.com/docs/modules/ag Create a test plan that references `MyLangChainTarget`. -!!! example "agenteval.yaml" +!!! example "agenteval.yml" ```yaml evaluator: diff --git a/docs/user_guide.md b/docs/user_guide.md index ba918d9..36e9bb6 100644 --- a/docs/user_guide.md +++ b/docs/user_guide.md @@ -6,7 +6,7 @@ To begin, initialize a test plan. agenteval init ``` -This will create an `agenteval.yaml` file in the current directory. +This will create an `agenteval.yml` file in the current directory. ```yaml evaluator: @@ -99,3 +99,108 @@ tests: - The agent returns the details on claim-006. initial_prompt: Can you let me know which claims are open? ``` + +## Evaluation hooks +You can specify hooks that run before and/or after evaluating a test. This is useful for performing integration testing, as well as any setup or cleanup tasks required. + +To create hooks, you will define a subclass of [Hook](reference/hook.md#src.agenteval.hook.Hook). + +For a hook that runs *before* evaluation, implement a `pre_evaluate` method. In this method, you have access to the [Test](reference/test.md#src.agenteval.test.Test) and [TraceHandler](reference/trace_handler.md#src.agenteval.trace_handler.TraceHandler) via the `test` and `trace` arguments, respectively. + +For a hook that runs *after* evaluation, implement a `post_evaluate` method. Similar to the `pre_evaluate` method, you have access to the `Test` and `TraceHandler`. You also have access to the [TestResult](reference/test_result.md#src.agenteval.test_result.TestResult) via the `test_result` argument. You may override the attributes of the `TestResult` if you plan to use this hook to perform additional testing, such as integration testing. + +!!! 
example "my_evaluation_hook.py" + + ```python + from agenteval import Hook + + class MyEvaluationHook(Hook): + + def pre_evaluate(test, trace): + # implement logic here + + def post_evaluate(test, test_result, trace): + # implement logic here + ``` + +Once you have created your subclass, specify the hook for the test in `agenteval.yml`. + +!!! example "agenteval.yml" + + ```yaml + tests: + - name: MakeReservation + hook: my_evaluation_hook.MyEvaluationHook + ``` + +### Examples + +#### Integration testing using `post_evaluate` + +In this example, we will test an agent that can make dinner reservations. In addition to evaluating the conversation, we want to test that the reservation is written to the backend database. To do this, we will create a post evaluation hook that queries a PostgreSQL database for the reservation record. If the record is not found, we will override the `TestResult`. + +!!! example "test_record_insert.py" + + ```python + import boto3 + import json + import psycopg2 + + from agenteval import Hook + + SECRET_NAME = "test-secret" + + def get_db_secret() -> dict: + + session = boto3.session.Session() + client = session.client( + service_name='secretsmanager', + ) + get_secret_value_response = client.get_secret_value( + SecretId=SECRET_NAME + ) + secret = get_secret_value_response['SecretString'] + + return json.loads(secret) + + class TestRecordInsert(Hook): + + def post_evaluate(test, test_result, trace): + + # get database secret from AWS Secrets Manager + secret = get_db_secret() + + # connect to database + conn = psycopg2.connect( + database=secret["dbname"], + user=secret["username"], + password=secret["password"], + host=secret["host"], + port=secret["port"] + ) + + # check if record is written to database + with conn.cursor() as cur: + cur.execute("SELECT * FROM reservations WHERE name = 'Bob'") + row = cur.fetchone() + + # override the test result based on query result + if not row: + test_result.success = False + test_result.result = "Integration test failed" + test_result.reasoning = "Record was not inserted into the database" + ``` + +Create a test that references the hook. + +!!! example "agenteval.yml" + + ```yaml + tests: + - name: MakeReservation + steps: + - Ask agent to make a reservation under the name Bob for 7 PM. + expected_results: + - The agent confirms that a reservation has been made. 
+ hook: test_record_insert.TestRecordInsert + ``` \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index c847fcf..d92e6c9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -40,6 +40,7 @@ nav: - cli: reference/cli.md - conversation_handler: reference/conversation_handler.md - evaluator: reference/evaluator.md + - hook: reference/hook.md - target: reference/target.md - test: reference/test.md - test_result: reference/test_result.md diff --git a/src/agenteval/__init__.py b/src/agenteval/__init__.py index d34a7e6..d054917 100644 --- a/src/agenteval/__init__.py +++ b/src/agenteval/__init__.py @@ -6,6 +6,9 @@ from jinja2 import Environment, PackageLoader from rich.logging import RichHandler +from .hook import Hook + +__all__ = ["Hook"] _LOG_LEVEL_ENV = "LOG_LEVEL" _FORMAT = "%(message)s" diff --git a/src/agenteval/evaluators/claude/claude.py b/src/agenteval/evaluators/claude/claude.py index ca5b5b3..39948e5 100644 --- a/src/agenteval/evaluators/claude/claude.py +++ b/src/agenteval/evaluators/claude/claude.py @@ -166,7 +166,7 @@ def _invoke_target(self, user_input) -> str: return target_response.response - def run(self) -> TestResult: + def evaluate(self) -> TestResult: success = False eval_reasoning = "" result = "Max turns reached." diff --git a/src/agenteval/evaluators/evaluator.py b/src/agenteval/evaluators/evaluator.py index 05cbc8e..5ee4c00 100644 --- a/src/agenteval/evaluators/evaluator.py +++ b/src/agenteval/evaluators/evaluator.py @@ -6,10 +6,12 @@ from botocore.client import BaseClient from agenteval.conversation_handler import ConversationHandler +from agenteval.hook import Hook from agenteval.targets import BaseTarget from agenteval.test import Test from agenteval.test_result import TestResult from agenteval.trace_handler import TraceHandler +from agenteval.utils import import_class, validate_subclass _BEDROCK_RUNTIME_BOTO3_NAME = "bedrock-runtime" @@ -19,27 +21,29 @@ class BaseEvaluator(ABC): classes. Attributes: - test (Test): The test to conduct. + test (Test): The test case. target (BaseTarget): The target agent being evaluated. conversation (ConversationHandler): Conversation handler for capturing the interaction between the evaluator (user) and target (agent). trace (TraceHandler): Trace handler for capturing steps during evaluation. + test_result (TestResult): The result of the test which is set in `BaseEvaluator.run`. """ def __init__(self, test: Test, target: BaseTarget): """Initialize the evaluator instance for a given `Test` and `Target`. Args: - test (Test): The test to conduct. + test (Test): The test case. target (BaseTarget): The target agent being evaluated. """ self.test = test self.target = target self.conversation = ConversationHandler() self.trace = TraceHandler(test_name=test.name) + self.test_result = None @abstractmethod - def run(self) -> TestResult: + def evaluate(self) -> TestResult: """Conduct a test. Returns: @@ -47,13 +51,28 @@ def run(self) -> TestResult: """ pass - def trace_run(self): + def _get_hook_cls(self, hook: Optional[str]) -> Optional[type[Hook]]: + if hook: + hook_cls = import_class(hook) + validate_subclass(hook_cls, Hook) + return hook_cls + + def run(self) -> TestResult: """ - Runs the evaluator within a trace context manager to - capture trace data. + Run the evaluator within a trace context manager and run hooks + if provided. 
""" + + hook_cls = self._get_hook_cls(self.test.hook) + with self.trace: - return self.run() + if hook_cls: + hook_cls.pre_evaluate(self.test, self.trace) + self.test_result = self.evaluate() + if hook_cls: + hook_cls.post_evaluate(self.test, self.test_result, self.trace) + + return self.test_result class AWSEvaluator(BaseEvaluator): diff --git a/src/agenteval/hook.py b/src/agenteval/hook.py new file mode 100644 index 0000000..db86867 --- /dev/null +++ b/src/agenteval/hook.py @@ -0,0 +1,30 @@ +from agenteval.test import Test +from agenteval.test_result import TestResult +from agenteval.trace_handler import TraceHandler + + +class Hook: + """An evaluation hook.""" + + def pre_evaluate(test: Test, trace: TraceHandler) -> None: + """ + Method called before evaluation. Can be used to perform any setup tasks. + + Args: + test (Test): The test case. + trace (TraceHandler): Trace handler for capturing steps during evaluation. + """ + pass + + def post_evaluate(test: Test, test_result: TestResult, trace: TraceHandler) -> None: + """ + Method called after evaluation. This may be used to perform integration testing + or clean up tasks. + + Args: + test (Test): The test case. + test_result (TestResult): The result of the test, which can be overriden + by updating the attributes of this object. + trace (TraceHandler): Trace handler for capturing steps during evaluation. + """ + pass diff --git a/src/agenteval/plan.py b/src/agenteval/plan.py index 75ca1e7..76b63bd 100644 --- a/src/agenteval/plan.py +++ b/src/agenteval/plan.py @@ -1,10 +1,9 @@ from __future__ import annotations -import importlib import logging import os import sys -from typing import Optional, Union +from typing import Optional import yaml from pydantic import BaseModel, model_validator @@ -18,6 +17,7 @@ SageMakerEndpointTarget, ) from agenteval.test import Test +from agenteval.utils import import_class, validate_subclass _PLAN_FILE_NAME = "agenteval.yml" @@ -86,10 +86,8 @@ def create_evaluator(self, test: Test, target: BaseTarget) -> BaseEvaluator: if self.evaluator_config.type in _EVALUATOR_MAP: evaluator_cls = _EVALUATOR_MAP[self.evaluator_config.type] else: - evaluator_cls = self._get_class(self.evaluator_config) - - if not issubclass(evaluator_cls, BaseEvaluator): - raise TypeError(f"{evaluator_cls} is not a Evaluator subclass") + evaluator_cls = import_class(self.evaluator_config.type) + validate_subclass(evaluator_cls, BaseEvaluator) return evaluator_cls( test=test, @@ -103,23 +101,11 @@ def create_target( if self.target_config.type in _TARGET_MAP: target_cls = _TARGET_MAP[self.target_config.type] else: - target_cls = self._get_class(self.target_config) - - if not issubclass(target_cls, BaseTarget): - raise TypeError(f"{target_cls} is not a Target subclass") + target_cls = import_class(self.target_config.type) + validate_subclass(target_cls, BaseTarget) return target_cls(**self.target_config.model_dump(exclude="type")) - def _get_class(self, config: Union[EvaluatorConfig, TargetConfig]) -> type: - module_path, class_name = config.type.rsplit(".", 1) - return self._import_class(module_path, class_name) - - @staticmethod - def _import_class(module_path: str, class_name: str) -> type[BaseTarget]: - module = importlib.import_module(module_path) - target_cls = getattr(module, class_name) - return target_cls - @staticmethod def _load_yaml(path: str) -> dict: with open(path) as stream: @@ -136,6 +122,7 @@ def _load_tests(test_config: list[dict]) -> list[Test]: expected_results=t.get("expected_results"), 
initial_prompt=t.get("initial_prompt"), max_turns=t.get("max_turns", defaults.MAX_TURNS), + hook=t.get("hook"), ) ) return tests diff --git a/src/agenteval/runner/runner.py b/src/agenteval/runner/runner.py index 0262274..1865cb8 100644 --- a/src/agenteval/runner/runner.py +++ b/src/agenteval/runner/runner.py @@ -63,7 +63,7 @@ def run_test(self, test): target=target, ) - result = evaluator.trace_run() + result = evaluator.run() if result.success is False: self.num_failed += 1 diff --git a/src/agenteval/test.py b/src/agenteval/test.py index 8b1eb31..f6de803 100644 --- a/src/agenteval/test.py +++ b/src/agenteval/test.py @@ -7,11 +7,12 @@ class Test(BaseModel, validate_assignment=True): """A test case for an agent. Attributes: - name: Name of the test - steps: List of step to perform for the test - expected_results: List of expected results for the test - initial_prompt: Optional initial prompt - max_turns: Maximum number of turns allowed for the test + name: Name of the test. + steps: List of step to perform for the test. + expected_results: List of expected results for the test. + initial_prompt: Optional initial prompt. + max_turns: Maximum number of turns allowed for the test. + hook: The module path to an evaluation hook. """ # do not collect as a test @@ -22,3 +23,4 @@ class Test(BaseModel, validate_assignment=True): expected_results: list[str] initial_prompt: Optional[str] = None max_turns: int + hook: Optional[str] = None diff --git a/src/agenteval/utils/__init__.py b/src/agenteval/utils/__init__.py new file mode 100644 index 0000000..038b2dc --- /dev/null +++ b/src/agenteval/utils/__init__.py @@ -0,0 +1,3 @@ +from .imports import import_class, validate_subclass + +__all__ = ["import_class", "validate_subclass"] diff --git a/src/agenteval/utils/imports.py b/src/agenteval/utils/imports.py new file mode 100644 index 0000000..4b92503 --- /dev/null +++ b/src/agenteval/utils/imports.py @@ -0,0 +1,14 @@ +from importlib import import_module + + +def import_class(module_path: str) -> type: + name, class_name = module_path.rsplit(".", 1) + + return getattr(import_module(name), class_name) + + +def validate_subclass(child_class: type, parent_class: type) -> None: + if not issubclass(child_class, parent_class): + raise TypeError( + f"{child_class.__name__} is not a {parent_class.__name__} subclass" + ) diff --git a/tests/src/agenteval/evaluators/claude/test_claude.py b/tests/src/agenteval/evaluators/claude/test_claude.py index 0dd1f62..95eca47 100644 --- a/tests/src/agenteval/evaluators/claude/test_claude.py +++ b/tests/src/agenteval/evaluators/claude/test_claude.py @@ -112,7 +112,7 @@ def test_run_single_turn_success(self, mocker, evaluator_fixture): "", ) - result = evaluator_fixture.run() + result = evaluator_fixture.evaluate() assert result.success is True @@ -140,7 +140,7 @@ def test_run_single_turn_initial_prompt_success(self, mocker, evaluator_fixture) "", ) - result = evaluator_fixture.run() + result = evaluator_fixture.evaluate() assert result.success is True assert mock_generate_initial_prompt.call_count == 1 @@ -170,7 +170,7 @@ def test_run_multi_turn_success(self, mocker, evaluator_fixture): ) mock_generate_user_response.return_value = "test user response" - result = evaluator_fixture.run() + result = evaluator_fixture.evaluate() assert result.success is True assert mock_generate_task_status.call_count == 2 @@ -193,7 +193,7 @@ def test_run_max_turns_exceeded(self, mocker, evaluator_fixture): ) mock_generate_user_response.return_value = "test user response" - result = 
evaluator_fixture.run() + result = evaluator_fixture.evaluate() assert result.success is False assert mock_generate_user_response.call_count == 1 diff --git a/tests/src/agenteval/evaluators/test_evaluator.py b/tests/src/agenteval/evaluators/test_evaluator.py index 7ff49aa..8b50495 100644 --- a/tests/src/agenteval/evaluators/test_evaluator.py +++ b/tests/src/agenteval/evaluators/test_evaluator.py @@ -3,7 +3,7 @@ class DummyAWSEvaluator(evaluator.AWSEvaluator): - def run(self): + def evaluate(self): pass diff --git a/tests/src/agenteval/test_plan.py b/tests/src/agenteval/test_plan.py index 2b7739a..bbb7248 100644 --- a/tests/src/agenteval/test_plan.py +++ b/tests/src/agenteval/test_plan.py @@ -57,40 +57,25 @@ def test_create_target_aws_type(self, mocker, plan_fixture): assert isinstance(target_cls, plan.BedrockAgentTarget) def test_create_target_custom_type(self, mocker, plan_with_custom_target_fixture): - mock_get_class = mocker.patch.object( - plan_with_custom_target_fixture, "_get_class" - ) + mock_import_class = mocker.patch("src.agenteval.plan.import_class") - mock_get_class.return_value = CustomTarget + mock_import_class.return_value = CustomTarget target = plan_with_custom_target_fixture.create_target() assert isinstance(target, CustomTarget) def test_create_target_not_subclass(self, mocker, plan_with_custom_target_fixture): - mock_get_class = mocker.patch.object( - plan_with_custom_target_fixture, "_get_class" - ) + mock_import_class = mocker.patch("src.agenteval.plan.import_class") class InvalidTarget: pass - mock_get_class.return_value = InvalidTarget + mock_import_class.return_value = InvalidTarget with pytest.raises(TypeError): plan_with_custom_target_fixture.create_target() - def test_get_class(self, mocker, plan_with_custom_target_fixture): - mock_import_class = mocker.patch.object( - plan_with_custom_target_fixture, "_import_class" - ) - mock_import_class.return_value = CustomTarget - - target_cls = plan_with_custom_target_fixture._get_class( - config=plan_with_custom_target_fixture.target_config, - ) - assert target_cls == CustomTarget - def test_load_tests(self): tests = plan.Plan._load_tests( test_config=[ diff --git a/tests/src/agenteval/test_trace_handler.py b/tests/src/agenteval/test_trace_handler.py index 60d5b2a..f00bc3b 100644 --- a/tests/src/agenteval/test_trace_handler.py +++ b/tests/src/agenteval/test_trace_handler.py @@ -12,6 +12,20 @@ class TestTraceHandler: def test_init(self, trace_handler_fixture): assert trace_handler_fixture.steps == [] assert trace_handler_fixture.test_name == "my_test" + assert trace_handler_fixture.start_time is None + assert trace_handler_fixture.end_time is None + + def test_enter(self, mocker, trace_handler_fixture): + + mock_dump_trace = mocker.patch.object(trace_handler_fixture, "_dump_trace") + + with trace_handler_fixture: + pass + + assert isinstance(trace_handler_fixture.start_time, trace_handler.datetime) + assert isinstance(trace_handler_fixture.end_time, trace_handler.datetime) + + mock_dump_trace.assert_called_once() def test_add_step(self, trace_handler_fixture): diff --git a/tests/src/agenteval/utils/test_module_import.py b/tests/src/agenteval/utils/test_module_import.py new file mode 100644 index 0000000..9d04468 --- /dev/null +++ b/tests/src/agenteval/utils/test_module_import.py @@ -0,0 +1,37 @@ +from src.agenteval.utils import imports +import pytest + + +def test_import_class(mocker): + + mock_import_module = mocker.patch("src.agenteval.utils.imports.import_module") + + test_class = 
imports.import_class("path.to.module.TestClass") + + mock_import_module.assert_called_once_with("path.to.module") + assert test_class == mock_import_module.return_value.TestClass + + +def test_validate_subclass(): + + class Parent: + pass + + class Child(Parent): + pass + + result = imports.validate_subclass(Child, Parent) + + assert result is None + + +def test_validate_subclass_error(): + + class Parent: + pass + + class Child: + pass + + with pytest.raises(TypeError): + imports.validate_subclass(Child, Parent)
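
The user guide changes above demonstrate `post_evaluate` for integration testing but only mention that `pre_evaluate` "can be used to perform any setup tasks." The sketch below is a minimal, illustrative setup hook under that assumption; it follows the same conventions as the documented examples (methods defined without `self`, since `BaseEvaluator.run` invokes them on the class), and the module name, connection parameters, and table name are hypothetical.

```python
import psycopg2

from agenteval import Hook


class SeedReservationsTable(Hook):
    """Hypothetical setup hook that clears stale test data before evaluation."""

    def pre_evaluate(test, trace):
        # placeholder connection parameters -- adjust for a real environment
        conn = psycopg2.connect(dbname="reservations_test", host="localhost")
        try:
            with conn.cursor() as cur:
                # remove any record left over from a previous run of this test
                cur.execute("DELETE FROM reservations WHERE name = 'Bob'")
            conn.commit()
        finally:
            conn.close()
```

As with the other examples, such a hook would be attached to a test via its module path, e.g. `hook: seed_reservations_table.SeedReservationsTable` in `agenteval.yml`.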